题目链接

题目描述

给定一棵 $n$ 个节点的有根树,有 $q$ 组询问

每次询问给出 $(a,k)$,求有多少个三元组 $(a,b,c)$,满足 $a,b$ 都是 $c$ 的祖先,并且 $a,b$ 之间的距离不超过 $k$

$n,q \le 3 \times 10^5$

简要做法

满足 $a,b$ 都是 $c$ 的祖先,$a,b,c$ 肯定是在一条链上的

分类讨论 $b$ 的位置

  • $b$ 在 $a$ 上方

    贡献为 $\sum_b{(size_a-1)}$

    $b$ 有 $min(depth_a,k)$ 种取法

    对于每个钦定的 $b$,$a$ 已经确定了,$c$ 可以在 $a$ 的子树内任选(不能和 $a$ 重合),那么有 $size_a - 1$ 中取法

    所以贡献为 $min(depth_a,k) \times (size_a - 1)$ 种

  • $b$ 在 $a$ 下方

    总共 $\sum_b{(size_b-1)}$

    考虑统计有多少个 $b$,满足 $depth_b \in [depth_a+1,depth_a+k]$

    把 $dfn$ 和 $depth$ 看成 $x$ 和 $y$ 轴,那么就可以把树变换为一个个二维数点,每个点 $x$ 的权值为 $(size_x-1)$

    每次询问等价于在这个二维平面上,查询 $x \in [dfn_a+1,dfn_a+size_a-1]$,$y \in [depth_a+1,depth_a+k]$,这个矩形内点集的权值和

参考代码

#include <stdio.h>
#include <algorithm>
#include <memory.h>

#define int long long

const int N = 3e5 + 5;
const int M = N << 1;

int n, q;
int head[N], num_edge, max_depth;
int size[N], depth[N], dfn[N], idx[N];

struct Node
{
    int next, to;
} edge[M];

void add_edge(int u, int v) { edge[++num_edge] = Node{head[u], v}, head[u] = num_edge; }

int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while ('0' > ch or ch > '9')
        f = ch == '-' ? -1 : 1, ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}

struct HJT_Tree
{
#define mid ((l + r) >> 1)
    int nodecnt, rt[N << 5], L[N << 5], R[N << 5], val[N << 5];
    void modify(int &p, int pre, int l, int r, int pos, int x)
    {
        p = ++nodecnt;
        L[p] = L[pre], R[p] = R[pre], val[p] = val[pre] + x;
        if (l < r)
        {
            if (pos <= mid)
                modify(L[p], L[pre], l, mid, pos, x);
            else
                modify(R[p], R[pre], mid + 1, r, pos, x);
        }
    }
    int query(int u, int v, int l, int r, int x, int y)
    {
        if (l == x and r == y)
            return val[v] - val[u];
        int res = 0;
        if (x <= mid)
            res += query(L[u], L[v], l, mid, x, std::min(mid, y));
        if (mid < y)
            res += query(R[u], R[v], mid + 1, r, std::max(mid + 1, x), y);
        return res;
    }
} T;

void dfs1(int u, int fa)
{
    size[u] = 1, depth[u] = depth[fa] + 1, max_depth = std::max(max_depth, depth[u]), dfn[u] = ++dfn[0], idx[dfn[u]] = u;
    for (int i = head[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa)
            dfs1(v, u), size[u] += size[v];
}

signed main()
{
    n = read(), q = read();
    for (int i = 1, u, v; i < n; i++)
        u = read(), v = read(), add_edge(u, v), add_edge(v, u);
    dfs1(1, 0);
    for (int i = 1; i <= n; i++)
        T.modify(T.rt[dfn[idx[i]]], T.rt[dfn[idx[i]] - 1], 1, max_depth, depth[idx[i]], size[idx[i]] - 1);
    for (int res; q--; printf("%lld\n", res))
    {
        int a = read(), k = read();
        res = std::min(k, depth[a] - 1) * (size[a] - 1);
        int l = dfn[a] + 1, r = dfn[a] + size[a] - 1;
        res += T.query(T.rt[l - 1], T.rt[r], 1, max_depth, depth[a] + 1, std::min(max_depth, depth[a] + k));
    }
    return 0;
}