题目链接

题目描述

有一个 $n \times n$ 的正方形网格,用红色,绿色,蓝色三种颜色染色,求有多少种染色方案使得至少一行或一列是同一种颜色。

结果对 $998244353$ 取模

$n \leq 1e6$

简要做法

颜色相同的方案数,$f(i)$ 钦定,$g(i)$ 恰好

$$f(x,y) = \sum_{i=x}^n\sum_{j=y}^n\dbinom{i}{x}\dbinom{i}{y}g(i,j)$$

高维二项式反演,高维反演即为每个维度的反演系数直接乘起来

$$g(x,y) = \sum_{i=x}^n\sum_{j=y}^n(-1)^{i+j-x-y}\dbinom{i}{x}\dbinom{i}{y}f(i,j)$$

得出答案通式

$$ans=3^{n^2} - g(0,0)$$

大力讨论 $f(i,j)$

$$ \begin{aligned} f(i,j) = \begin{cases} \dbinom{n}{i}\dbinom{n}{j}3^{(n-i)(n-j)+1} &ij \not = 0 \\
\\
\dbinom{n}{i+j}3^{n(n-i-j)+i+j} &ij=0,i+j \not =0 \\
\\
3^{n^2} &i+j=0 \end{cases}\\
\end{aligned} $$

开始计算 $g(0,0)$

$$ \begin{aligned} g(0,0) = & \sum_{i=0}^n\sum_{j=0}^n(-1)^{i+j}f(i,j) \\
= & \sum_{i=1}^n\sum_{j=1}^n(-1)^{i+j}f(i,j) + 2\sum_{i=1}^n(-1)^if(i,0)+f(0,0) \end{aligned} $$

计算第一个和式的复杂度为 $O(n^2)$ ,太慢了,尝试化简

$$ \begin{aligned} \sum_{i=1}^n\sum_{j=1}^n(-1)^{i+j}f(i,j) = & \sum_{i=1}^n\sum_{j=1}^n(-1)^{i+j}\dbinom{n}{i}\dbinom{n}{j}3^{(n-i)(n-j)+1}\\
=&\sum_{i=1}^n(-1)^i\dbinom{n}{i}\sum_{j=1}^n(-1)^{j}\dbinom{n}{j}3^{n^2-(i+j)n+ij+1}\\
=&3^{n^2+1}\sum_{i=1}^n(-1)^i\dbinom{n}{i}3^{-in}\sum_{j=1}^n\dbinom{n}{j}(-3^{i-n})^j\\
=&3^{n^2+1}\sum_{i=1}^n(-1)^i\dbinom{n}{i}3^{-in}((1-3^{i-n})^n-1) \end{aligned} $$

合并可得

$$ \begin{aligned} ans =& 3^{n^2} - g(0,0)\\
=& -3^{n^2+1}\sum_{i=1}^n(-1)^i\dbinom{n}{i}3^{-in}((1-3^{i-n})^n-1) - 2\sum_{i=1}^n(-1)^i\dbinom{n}{i}3^{n(n-i)+i}\\
\end{aligned} $$

参考代码

#include <stdio.h>
#include <algorithm>

#define int long long

int read(int x = 0, int f = 0, char ch = getchar())
{
    while ('0' > ch or ch > '9')
        f = ch == '-', ch = getchar();
    while ('0' <= ch and ch <= '9')
        x = x * 10 + (ch ^ 48), ch = getchar();
    return f ? -x : x;
}

const int N = 1e6 + 5;
const int P = 998244353;

int n, ans, res;
int inv[N], fac[N];

int pow(int x, int k, int res = 1)
{
    for (x = (x + P) % P, k = k % (P - 1); k; k >>= 1, x = x * x % P)
        if (k & 1)
            res = res * x % P;
    return res;
}

int C(int n, int m) { return fac[n] * inv[m] % P * inv[n - m] % P; }

signed main()
{
    n = read(), fac[0] = 1;
    for (int i = 1; i <= n; i++)
        fac[i] = fac[i - 1] * i % P;
    inv[n] = pow(fac[n], P - 2);
    for (int i = n; i; i--)
        inv[i - 1] = inv[i] * i % P;
    for (int i = 1; i <= n; i++)
    {
        int tmp = C(n, i) * pow(pow(3, i * n), P - 2) % P * (pow(1 - pow(pow(3, n - i), P - 2), n) - 1 + P) % P;
        if (i & 1)
            ans = (ans - tmp + P) % P;
        else
            ans = (ans + tmp + P) % P;
    }
    for (int i = 1; i <= n; i++)
    {
        int tmp = C(n, i) * pow(3, n * (n - i) + i) % P;
        if (i & 1)
            res = (res - tmp + P) % P;
        else
            res = (res + tmp + P) % P;
    };
    ans = (ans * -pow(3, n * n + 1) - 2 * res + P) % P;
    printf("%lld", (ans + P) % P);
    return 0;
}