题目链接

题目描述

给定一棵树,割断每一条边都有代价,每次询问会给定一些点,求用最少的代价使所有给定点都和1号节点不连通

虚树这东西就是每次只把有用的点留下,不用的点删了

首先你树形 dp 要好,不然这东西学了也没用

关于如何构建虚树

单调栈建虚树

每次用单调栈维护一条链

OI wiki 上讲的很详细,还带图片

这做法虽然跑得快,但是写着好不爽,我选择欧拉序

#include <stdio.h>
#include <algorithm>
#include <memory.h>

typedef long long ll;

const int N = 25e4 + 5;
const int M = N << 1;
const ll INF = 1LL << 60;

int n, m;
int dfn[N], idx[N], top[N], son[N], fa[N], depth[N], size[N];
ll min[N], f[N];
int a[N], stack[N], ceil;

struct Graph
{
    int head[N], num_edge;
    struct Node
    {
        int next, to, dis;
    } edge[M];
    void add_edge(int u, int v, int dis = 0) { edge[++num_edge] = Node{head[u], v, dis}, head[u] = num_edge; }
    void dfs1(int u, int fa)
    {
        ::fa[u] = fa, size[u] = 1, depth[u] = depth[fa] + 1;
        for (int i = head[u], v; i; i = edge[i].next)
            if ((v = edge[i].to) != fa)
                min[v] = std::min(min[u], (ll)edge[i].dis), dfs1(v, u), size[u] += size[v], son[u] = size[v] > size[son[u]] ? v : son[u];
    }
    void dfs2(int u)
    {
        dfn[u] = ++dfn[0], idx[dfn[u]] = u, top[u] = u == son[fa[u]] ? top[fa[u]] : u;
        if (son[u])
            dfs2(son[u]);
        for (int i = head[u], v; i; i = edge[i].next)
            if ((v = edge[i].to) != son[u] and v != fa[u])
                dfs2(v);
    }
    int lca(int x, int y)
    {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dfn[top[x]] < dfn[top[y]])
                std::swap(x, y);
        return dfn[x] < dfn[y] ? x : y;
    }
    void insert(int u)
    {
        if (ceil == 1)
        {
            if (u != 1)
                stack[++ceil] = u;
            return;
        }
        int x = lca(u, stack[ceil]);
        if (x == stack[ceil])
            return;
        for (; ceil > 1 and dfn[stack[ceil - 1]] >= dfn[x]; ceil--)
            add_edge(stack[ceil - 1], stack[ceil]);
        if (stack[ceil] != x)
            add_edge(x, stack[ceil]), stack[ceil] = x;
        stack[++ceil] = u;
    }
    void dp(int u)
    {
        if (!head[u])
        {
            f[u] = min[u];
            return;
        }
        f[u] = 0;
        for (int i = head[u], v; i; i = edge[i].next)
            dp(v = edge[i].to), f[u] += f[v];
        f[u] = std::min(f[u], min[u]);
        head[u] = 0;
    }
} G, VT;

bool cmp(int a, int b) { return dfn[a] < dfn[b]; }

int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while ('0' > ch or ch > '9')
        f = ch == '-' ? -1 : 1, ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

int main()
{
    n = read();
    for (int i = 1, u, v, dis; i < n; i++)
        u = read(), v = read(), dis = read(), G.add_edge(u, v, dis), G.add_edge(v, u, dis);
    min[1] = INF, G.dfs1(1, 0), G.dfs2(1);
    for (m = read(); m--;)
    {
        int k = read();
        for (int i = 1; i <= k; i++)
            a[i] = read();
        std::sort(a + 1, a + 1 + k, cmp);
        VT.num_edge = 0;
        stack[ceil = 1] = 1;
        for (int i = 1; i <= k; i++)
            VT.insert(a[i]);
        for (; ceil > 1; ceil--)
            VT.add_edge(stack[ceil - 1], stack[ceil]);
        VT.dp(1);
        printf("%lld\n", f[1]);
    }
    return 0;
}

欧拉序建虚树

好理解,写起来容易

先把给定的点按照欧拉序排个序

那么虚树里肯定有且只有这些点和数组里相邻两个点的 lca 和 根节点 (记得去重)

然后再把所有点复制一遍,用负数表示它是出栈的点

然后按照欧拉序再排个序,这数组就变成我们在虚树上 dfs 时的栈啦

按照这个栈做就行了,连虚树的邻接表都不用搞,方便的 eb

稍微慢了一点,应该不会卡吧(大概

#include <stdio.h>
#include <algorithm>
#include <memory.h>

typedef long long ll;

const int N = 3e6 + 5;
const int M = N << 1;
const ll INF = 1LL << 60;

int n, m;
int dfin[N], dfou[N], top[N], son[N], fa[N], depth[N], size[N];
long long f[N], min[N];
int head[N], num_edge;
int a[N], stack[N];
bool b[N];

struct Node
{
    int next, to, dis;
} edge[M];

void add_edge(int u, int v, int dis = 0) { edge[++num_edge] = Node{head[u], v, dis}, head[u] = num_edge; }

int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while ('0' > ch or ch > '9')
        f = ch == '-' ? -1 : 1, ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

void dfs1(int u, int fa)
{
    ::fa[u] = fa, size[u] = 1, depth[u] = depth[fa] + 1;
    for (int i = head[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa)
            min[v] = std::min(min[u], (ll)edge[i].dis), dfs1(v, u), size[u] += size[v], son[u] = size[v] > size[son[u]] ? v : son[u];
}

void dfs2(int u)
{
    dfin[u] = ++dfin[0], top[u] = u == son[fa[u]] ? top[fa[u]] : u;
    if (son[u])
        dfs2(son[u]);
    for (int i = head[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa[u] and v != son[u])
            dfs2(v);
    dfou[u] = ++dfin[0];
}

int lca(int x, int y)
{
    for (; top[x] != top[y]; x = fa[top[x]])
        if (dfin[top[x]] < dfin[top[y]])
            std::swap(x, y);
    return dfin[x] < dfin[y] ? x : y;
}

bool cmp(int x, int y)
{
    int k1 = x > 0 ? dfin[x] : dfou[-x];
    int k2 = y > 0 ? dfin[y] : dfou[-y];
    return k1 < k2;
}

int main()
{
    n = read();
    for (int i = 1, u, v, dis; i < n; i++)
        u = read(), v = read(), dis = read(), add_edge(u, v, dis), add_edge(v, u, dis);
    min[1] = INF, dfs1(1, 0), dfs2(1);
    for (m = read(); m--;)
    {
        int k = read();
        for (int i = 1; i <= k; i++)
            a[i] = read(), b[a[i]] = true, f[a[i]] = min[a[i]];
        std::sort(a + 1, a + 1 + k, cmp);
        for (int i = 1; i < k; i++)
        {
            int LCA = lca(a[i], a[i + 1]);
            if (!b[LCA])
                a[++k] = LCA, b[LCA] = true;
        }
        for (int i = 1, tmp = k; i <= tmp; i++)
            a[++k] = -a[i];
        if (!b[1])
            a[++k] = 1, a[++k] = -1;
        std::sort(a + 1, a + 1 + k, cmp);
        for (int i = 1, top = 0; i <= k; i++)
            if (a[i] > 0)
                stack[++top] = a[i];
            else
            {
                int u = stack[top--];
                if (u != 1)
                {
                    int fa = stack[top];
                    f[fa] += std::min(f[u], min[u]);
                }
                else
                    printf("%lld\n", f[1]);
                f[u] = b[u] = 0;
            }
    }
    return 0;
}