牛客多校第四场


E.triples ll

题意

让你用$n$个3的倍数,把$a$或出来,问你有几种方案,对998244353取模

思路

在二进制中$1,4,16,\mod 3余1$, 而 $2,8,32 \mod 3余2$

首先如果$a$中为$1$的二进制位在$b$中也都为$1$,那么就称$a$是$b$的子集

我们用$dp[i][j]$来表示二进制位$mod$ $3$ 余$1$的个数为$i$,$mod$ $3$ 余$2$的个数为$j$并且所有是3的倍数的子集个数

我们可以先一个$(logn)^4$求出所有的$dp$值

void Fac() {
    for (int i = 0; i < maxn; i ++) C[i][0] = 1;
    for (int i = 1; i < maxn; i ++) 
        for (int j = 1; j <= i; j ++) 
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
}

void init() {
    Fac();
    for (int x = 0; x < maxn; x ++) 
        for (int y = 0; y < maxn; y ++) 
            for (int i = 0; i <= x; i ++)
                for (int j = 0; j <= y; j ++) 
                    if((i+2*j)%3==0) dp[x][y] = (dp[x][y] + C[x][i] * C[y][j] % mod) % mod;
    dp[0][0] = 1;
}

如果要求或出来的结果是$a$的子集,那么方案数就是$(a的子集的个数)^n$

但是题目要求或出来的结果是$a$,那我们就要容斥一下了

打个比方,现在$a$的$\mod2$的个数为$1$,$\mod1$的个数为$2$

那么$dp[2][1]$所代表的是$<2,1>$的子集的个数,但有些子集在相或$n$次以后得不到a,这时候就要减掉那些不能或到$a$的子集

img1

$dp[2][1]$就像上图中由 A-E,​A-D,A-C,B-E,B-D,B-C 组成,而能或出来$a$的只有A-E,那么我们就要把其余不满足的减掉,也就是减掉$dp[1][1]+dp[2][0]$,发现减多了,我们要加上$dp[1][0]+dp[0][1]$,后面就跟容斥一样,多写几个样例就会发现当$(num1+num2-i-j)\mod2$时是减,其他是加。

注意:减的时候减掉的是不符合个数的$n$次幂

AC代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 70;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;

ll C[maxn][maxn], dp[maxn][maxn];
int cnt[2], o;

ll Ksm(ll a, ll b) {
    ll ret = 1;
    while(b) {
        if(b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}

void Fac() {
    for (int i = 0; i < maxn; i ++) C[i][0] = 1;
    for (int i = 1; i < maxn; i ++) 
        for (int j = 1; j <= i; j ++) 
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
}

void init() {
    Fac();
    for (int x = 0; x < maxn; x ++) 
        for (int y = 0; y < maxn; y ++) 
            for (int i = 0; i <= x; i ++)
                for (int j = 0; j <= y; j ++) 
                    if((i+2*j)%3==0) dp[x][y] = (dp[x][y] + C[x][i] * C[y][j] % mod) % mod;
    dp[0][0] = 1;
}

int main() {
    init();
    int t;
    scanf("%d", &t);
    while(t --) {
        ll n, a;
        scanf("%lld %lld", &n, &a);
        cnt[0] = cnt[1] = o = 0;
        while(a) {
            if(a & 1) cnt[o] ++;
            o ^= 1; a /= 2;
        }
        ll ans = 0;
        for (int i = 0; i <= cnt[0]; i ++) 
            for (int j = 0; j <= cnt[1]; j ++) {
                ll f = C[cnt[0]][i] * C[cnt[1]][j] % mod * Ksm(dp[i][j], n) % mod;
                if((cnt[0] + cnt[1] - i - j) & 1) f *= -1;
                ans = (ans + f + mod) % mod; 
            }        
        printf("%lld\n", ans);
    }
    return 0;
}

文章作者: Mug-9
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Mug-9 !
评论
  目录