题目链接

题目描述

给定一棵 $n=2m$ 个节点的树,有 $m$ 个白点,$m$ 个黑点

把黑白点两两配对,对于 $k\in[0,n]$,求出恰好有 $k$ 对点有祖孙关系的方案数

$n \le 5*10^3$

简要做法

$f(i)$ 恰好,$g(i)$ 钦定

$$g(k)=\sum_{i=k}^{m}{i \choose k}f(i)$$

$$f(k) = \sum_{i = k} ^ {m} (-1) ^ {i - k} {i \choose k} g(i)$$

树形 DP 求 $g(i)$

$dp(u,i)$ 表示 $u$ 的子树内,选出恰好 $i$ 组有祖孙关系的点对的方案数

假若不选择 $u$,合并子节点即可

假若选择 $u$,则 $dp(u,i+1) += dp(u,i) \times 颜色不同且未配对的点的数量$

$g(k) = dp(1,k) \times (m-k)!$

参考代码

#include <stdio.h>
#include <algorithm>
#include <memory.h>

#define int long long

int read(int x = 0, int f = 0, char ch = getchar())
{
    while ('0' > ch or ch > '9')
        f = ch == '-', ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + (ch ^ 48), ch = getchar();
    return f ? -x : x;
}

const int N = 5e3 + 5;
const int M = N << 1;
const int P = 998244353;

int n, ans;
int a[N];
int head[N], num_edge;
int fac[N], inv[N];
int f[N][N], g[N], siz1[N], siz2[N];

struct Node
{
    int next, to;
} edge[M];

void add_edge(int u, int v) { edge[++num_edge] = Node{head[u], v}, head[u] = num_edge; }

int pow(int x, int k, int res = 1)
{
    for (x %= P; k; x = x * x % P, k >>= 1)
        if (k & 1)
            res = res * x % P;
    return res;
}

int C(int i, int j) { return fac[i] * inv[j] % P * inv[i - j] % P; }

void dfs(int u, int fa)
{
    siz1[u] = f[u][0] = 1, siz2[u] = a[u];
    for (int i = head[u], v; i; i = edge[i].next)
        if ((v = edge[i].to) != fa)
        {
            dfs(v, u);
            for (int i = 0; i <= siz1[u] + siz1[v]; i++)
                g[i] = 0;
            for (int i = 0; i <= std::min(siz1[u], n / 2); i++)
                for (int j = 0; j <= std::min(siz1[v], n / 2 - i); j++)
                    (g[i + j] += f[u][i] * f[v][j] % P) %= P;
            for (int i = 0; i <= siz1[u] + siz1[v]; i++)
                f[u][i] = g[i];
            siz1[u] += siz1[v], siz2[u] += siz2[v];
        }
    for (int i = std::min(siz2[u], siz1[u] - siz2[u]); i; i--)
        if (a[u])
            (f[u][i] += f[u][i - 1] * (siz1[u] - siz2[u] - (i - 1)) % P) %= P;
        else
            (f[u][i] += f[u][i - 1] * (siz2[u] - (i - 1)) % P) %= P;
}

signed main()
{
    n = read(), fac[0] = 1;
    for (int i = 1; i <= n; i++)
        scanf("%1lld", &a[i]);
    for (int i = 1; i <= n; i++)
        fac[i] = fac[i - 1] * i % P;
    inv[n] = pow(fac[n], P - 2);
    for (int i = n; i >= 1; i--)
        inv[i - 1] = inv[i] * i % P;
    for (int i = 1, u, v; i < n; i++)
        u = read(), v = read(), add_edge(u, v), add_edge(v, u);
    dfs(1, 0);
    for (int i = 0; i <= n / 2; i++)
        (f[1][i] *= fac[n / 2 - i]) %= P;
    for (int i = 0; i <= n / 2; printf("%lld\n", ans), ans = 0, i++)
        for (int j = i; j <= n / 2; j++)
            if ((j - i) & 1)
                (ans += -C(j, i) * f[1][j] % P + P) %= P;
            else
                (ans += C(j, i) * f[1][j] % P) %= P;
    return 0;
}