BZOJ 3653 谈笑风生(主席树)
文章目录
题目描述
给定一棵 $n$ 个节点的有根树,有 $q$ 组询问
每次询问给出 $(a,k)$,求有多少个三元组 $(a,b,c)$,满足 $a,b$ 都是 $c$ 的祖先,并且 $a,b$ 之间的距离不超过 $k$
$n,q \le 3 \times 10^5$
简要做法
满足 $a,b$ 都是 $c$ 的祖先,$a,b,c$ 肯定是在一条链上的
分类讨论 $b$ 的位置
-
$b$ 在 $a$ 上方
贡献为 $\sum_b{(size_a-1)}$
$b$ 有 $min(depth_a,k)$ 种取法
对于每个钦定的 $b$,$a$ 已经确定了,$c$ 可以在 $a$ 的子树内任选(不能和 $a$ 重合),那么有 $size_a - 1$ 中取法
所以贡献为 $min(depth_a,k) \times (size_a - 1)$ 种
-
$b$ 在 $a$ 下方
总共 $\sum_b{(size_b-1)}$
考虑统计有多少个 $b$,满足 $depth_b \in [depth_a+1,depth_a+k]$
把 $dfn$ 和 $depth$ 看成 $x$ 和 $y$ 轴,那么就可以把树变换为一个个二维数点,每个点 $x$ 的权值为 $(size_x-1)$
每次询问等价于在这个二维平面上,查询 $x \in [dfn_a+1,dfn_a+size_a-1]$,$y \in [depth_a+1,depth_a+k]$,这个矩形内点集的权值和
参考代码
#include <stdio.h>
#include <algorithm>
#include <memory.h>
#define int long long
const int N = 3e5 + 5;
const int M = N << 1;
int n, q;
int head[N], num_edge, max_depth;
int size[N], depth[N], dfn[N], idx[N];
struct Node
{
int next, to;
} edge[M];
void add_edge(int u, int v) { edge[++num_edge] = Node{head[u], v}, head[u] = num_edge; }
int read()
{
int x = 0, f = 1;
char ch = getchar();
while ('0' > ch or ch > '9')
f = ch == '-' ? -1 : 1, ch = getchar();
while ('0' <= ch and ch <= '9')
x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
struct HJT_Tree
{
#define mid ((l + r) >> 1)
int nodecnt, rt[N << 5], L[N << 5], R[N << 5], val[N << 5];
void modify(int &p, int pre, int l, int r, int pos, int x)
{
p = ++nodecnt;
L[p] = L[pre], R[p] = R[pre], val[p] = val[pre] + x;
if (l < r)
{
if (pos <= mid)
modify(L[p], L[pre], l, mid, pos, x);
else
modify(R[p], R[pre], mid + 1, r, pos, x);
}
}
int query(int u, int v, int l, int r, int x, int y)
{
if (l == x and r == y)
return val[v] - val[u];
int res = 0;
if (x <= mid)
res += query(L[u], L[v], l, mid, x, std::min(mid, y));
if (mid < y)
res += query(R[u], R[v], mid + 1, r, std::max(mid + 1, x), y);
return res;
}
} T;
void dfs1(int u, int fa)
{
size[u] = 1, depth[u] = depth[fa] + 1, max_depth = std::max(max_depth, depth[u]), dfn[u] = ++dfn[0], idx[dfn[u]] = u;
for (int i = head[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa)
dfs1(v, u), size[u] += size[v];
}
signed main()
{
n = read(), q = read();
for (int i = 1, u, v; i < n; i++)
u = read(), v = read(), add_edge(u, v), add_edge(v, u);
dfs1(1, 0);
for (int i = 1; i <= n; i++)
T.modify(T.rt[dfn[idx[i]]], T.rt[dfn[idx[i]] - 1], 1, max_depth, depth[idx[i]], size[idx[i]] - 1);
for (int res; q--; printf("%lld\n", res))
{
int a = read(), k = read();
res = std::min(k, depth[a] - 1) * (size[a] - 1);
int l = dfn[a] + 1, r = dfn[a] + size[a] - 1;
res += T.query(T.rt[l - 1], T.rt[r], 1, max_depth, depth[a] + 1, std::min(max_depth, depth[a] + k));
}
return 0;
}