题目链接

题目描述

给定一个长为 $n$ 的数列,求所有 子数组 的 MEX 的 MEX

$1 \le n \le 10^5$

$1 \le a_i \le n$

简要做法

做这道题之前可以先去看看 [SDOI2009] HH的项链

枚举一个数 $x$,问题转化为判定 $x$ 是否是某个子数组的 MEX

如果小于 $x$ 的数最后一次出现的位置都在 $x$ 的位置和 $x$ 上一次出现的位置 之间,那么这个数就有了

参考代码

#include <stdio.h>
#include <algorithm>
#include <memory.h>

const int N = 1e5 + 5;
const int INF = 2147483647;

int n, a[N], vis[N], last[N];

int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while ('0' > ch or ch > '9')
        f = ch == '-' ? -1 : 1, ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

struct Segment_Tree
{
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid ((l + r) >> 1)
    int min[N << 2];
    void push_up(int p) { min[p] = std::min(min[ls], min[rs]); }
    void modify(int p, int l, int r, int pos, int val)
    {
        if (l == r)
        {
            min[p] = val;
            return;
        }
        if (pos <= mid)
            modify(ls, l, mid, pos, val);
        else
            modify(rs, mid + 1, r, pos, val);
        push_up(p);
    }
    int query(int p, int l, int r, int x, int y)
    {
        if (l == x and r == y)
            return min[p];
        int res = INF;
        if (x <= mid)
            res = std::min(res, query(ls, l, mid, x, std::min(mid, y)));
        if (mid < y)
            res = std::min(res, query(rs, mid + 1, r, std::max(mid + 1, x), y));
        return res;
    }
} T;

int main()
{
    n = read();
    for (int i = 1; i <= n; i++)
        a[i] = read();
    for (int i = 1; i <= n; i++)
    {
        if (a[i] != 1)
            vis[1] = 1;
        if (a[i] != 1)
        {
            int pos = T.query(1, 1, n, 1, a[i] - 1);
            if (pos > last[a[i]])
                vis[a[i]] = 1;
        }
        last[a[i]] = i;
        T.modify(1, 1, n, a[i], i);
    }
    for (int i = 2; i <= n + 1; i++)
    {
        int pos = T.query(1, 1, n, 1, i - 1);
        if (pos > last[i])
            vis[i] = 1;
    }
    int ans = 1;
    while (vis[ans] and ans <= n + 1)
        ans++;
    return printf("%d", ans), 0;
}