「HDU 6314」Matrix

Description

题目链接:HDU 6314

对于一个 $n\times m$ 的网格,每个格子只能涂上黑色或白色。求所有涂色方案中,至少有 $A$ 行 $B$ 列为黑色的方案数(对 $998244353$ 取模)。

数据范围:$1\le n,m,A,B\le 3000$


Solution

显然,行和列是相同的,于是我们可以把列去掉,记 $Ans(i)$ 表示至少有 $i$ 行全黑的方案数。

接下来考虑容斥,有以下式子

其中 $f_i$ 是 $Ans(a)$ 所对应的一个未知的容斥系数。

故我们只需要考虑如何求 $f_i$

考虑任意一个选了 $i$ 行且这 $i$ 行全黑的方案在上面的式子里计算的次数(每个方案最后应该只会被计算一次)

尝试推导 $f_i$ 的递推式

这样我们就可以在 $O(n^2)$ 的时间内递推出容斥系数。

因为行和列的问题是等价的,所以可以用相同的方法求出列的容斥系数。

记 $Ans(a,b)$ 表示至少 $a$ 行和 $b$ 列全黑的方案数,则有

时间复杂度:$O(n^2+m^2+nm)$


Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <cstdio>
const int N=3000;
const int mod=998244353;
int n,m,a,b,fa[N+5],fb[N+5],p[N*N+5],c[N+5][N+5];
void init() {
p[0]=1;
for(int i=1;i<=N*N;++i) {
p[i]=2*p[i-1];
if(p[i]>=mod) p[i]-=mod;
}
for(int i=0;i<=N;++i) c[i][0]=1;
for(int i=1;i<=N;++i) for(int j=0;j<=i;++j) {
c[i][j]=c[i-1][j-1]+c[i-1][j];
if(c[i][j]>=mod) c[i][j]-=mod;
}
}
void solve() {
fa[a]=1;
for(int i=a+1;i<=n;++i) {
int sum=0;
for(int j=a;j<i;++j) sum=(sum+1LL*c[i-1][j-1]*fa[j])%mod;
fa[i]=-sum;
}
fb[b]=1;
for(int i=b+1;i<=m;++i) {
int sum=0;
for(int j=b;j<i;++j) sum=(sum+1LL*c[i-1][j-1]*fb[j])%mod;
fb[i]=-sum;
}
}
int main() {
init();
while(~scanf("%d%d%d%d",&n,&m,&a,&b)) {
solve();
int ans=0;
for(int i=a;i<=n;++i) for(int j=b;j<=m;++j)
ans=(ans+1LL*fa[i]*c[n][i]%mod*fb[j]%mod*c[m][j]%mod*p[(n-i)*(m-j)])%mod;
printf("%d\n",(ans%mod+mod)%mod);
}
return 0;
}

Extended

上述代码的运行时间为 $700\text{ms}$ 左右,考虑如何优化。

此处,我们还是考虑只有行的情况。

如果我们强制有 $b$ 行为黑色,那么它对答案有 $\sum_{i=a}^b C_b^i$ 次的贡献(被计算进答案的次数)。

而事实上,我们想让它的贡献为(系数) $1$,考虑如下公式

故容斥系数为 $(-1)^{i-a}\times C_{i-1}^{a-1}$

通过优化,程序运行时间可以压到 $400\text{ms}$ 左右。

时间复杂度:$O(nm)$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <cstdio>
const int N=3005;
const int M=9000005;
const int P=998244353;
int n,m,a,b,o[N],x[N],y[N],p[M],c[N][N];
void init() {
o[0]=1;
for(int i=1;i<=3000;++i) o[i]=P-o[i-1];
p[0]=1;
for(int i=1;i<=9000000;++i) p[i]=2*p[i-1]%P;
for(int i=0;i<=3000;++i) {
c[i][0]=c[i][i]=1;
for(int j=1;j<i;++j) c[i][j]=(c[i-1][j-1]+c[i-1][j])%P;
}
}
int main() {
init();
while(~scanf("%d %d %d %d",&n,&m,&a,&b)) {
for(int i=a;i<=n;++i) x[i]=1LL*o[i-a]*c[n][i]%P*c[i-1][a-1]%P;
for(int i=b;i<=m;++i) y[i]=1LL*o[i-b]*c[m][i]%P*c[i-1][b-1]%P;
int ans=0;
for(int i=a;i<=n;++i) for(int j=b;j<=m;++j) ans=(ans+1LL*x[i]*y[j]%P*p[(n-i)*(m-j)])%P;
printf("%d\n",ans);
}
return 0;
}
0%