A: Blank
题意
有 $n (n \leq 100)$ 个格子,向其中填入 $0、1、2、3 $这$4$个数,但是有 $m ( m ≤ 100)$ 个限制
限制 $l$ $ r$ $x$ :表示 $l ~ r$ 的格子内不同的数的个数为$x$
要求满足所有限制的方案有多少种?
思路
我们首先设$dp[i][j][k][r]$为这$0,1,2,3$四个数字的最后一次出现的位置,$dp$值为方案数
那么转移可以这样写一下:
$dp[cur][j][k][r] += dp[i][j][k][r], dp[i][cur][k][r] += dp[i][j][k][r]$
$dp[i][j][cur][r] += dp[i][j][k][r], dp[i][j][k][cur] += dp[i][j][k][r]$
因为$i,j,k,r$互不相同, 且当位一定为一个数字并且相互之间有大小顺序,那么我们把$dp$按照大小来转移的话
还是$dp[cur][i][j][k]$ 其中$cur \geq i \geq j \geq k$
那么转移就变成
$dp[cur+1][i][j][k]+=dp[cur][i][j][k], dp[cur+1][cur][j][k] += dp[cur][i][j][k]$
$dp[cur+1][cur][i][k] += dp[cur][i][j][k], dp[cur+1][cur][i][j] += dp[cur][i][j]$
我们不必要区分$0,1,2,3$对应的是哪一个,因为这对结果没影响
这样的$dp$数组太大,我们可以用滚动数组来优化一下空间
AC代码
实测$dp$数组降序会$T$,可能是因为$dp$过程中地址变换太大造成超时.
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 1e2 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
typedef pair<int, int> pis;
vector<pis> lo[maxn];
ll dp[maxn][maxn][maxn][2];
//dp[i][j][k][cur] 升序
void add(ll &a, ll b) {
a = a + b;
if(a > mod) a -= mod;
if(a < 0) a += mod;
}
int main() {
int t;
scanf("%d", &t);
while(t --) {
int n, m;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i ++) {
lo[i].clear();
lo[i].push_back(pis{i, 1});
}
for (int i = 1; i <= m; i ++) {
int l, r, x;
scanf("%d %d %d", &l, &r, &x);
lo[r].push_back(pis{l, x});
}
memset(dp, 0, sizeof(dp));
dp[0][0][0][0] = 1;
for (int cur = 1; cur <= n; cur ++) {
int np = cur & 1;
for (int i = 0; i <= cur; i ++)
for (int j = i; j <= cur; j ++)
for (int k = j; k <= cur; k ++)
dp[i][j][k][np] = 0;
for (int i = 0; i <= cur; i ++)
for (int j = i; j <= cur; j ++)
for (int k = j; k <= cur; k ++) {
/*add(dp[i][k][cur-1][np], dp[i][j][k][np^1]);
地址跨越比add(dp[np][cur-1][k][i], dp[np^1][k][j][i]);
要大,可能是造成超时的原因
*/
add(dp[j][k][cur-1][np], dp[i][j][k][np^1]);
add(dp[i][k][cur-1][np], dp[i][j][k][np^1]);
add(dp[i][j][cur-1][np], dp[i][j][k][np^1]);
add(dp[i][j][k][np], dp[i][j][k][np^1]);
}
for (int i = 0; i <= cur; i ++)
for (int j = i; j <= cur; j ++)
for (int k = j; k <= cur; k ++)
for (pis it: lo[cur]) {
int l = it.first, r = cur, x = it.second;
int cnt = (i >= l) + (j >= l) + (k >= l) + 1;
if(cnt != x) dp[i][j][k][np] = 0;
}
}
ll ans = 0;
for (int i = 0; i <= n; i ++)
for (int j = i; j <= n; j ++)
for (int k = j; k <= n; k ++) add(ans, dp[i][j][k][n&1]);
printf("%lld\n", ans);
}
return 0;
}
L: Sequence
题意
给一个长度为n的数组,有m次操作,操作有3种,给一个x,每次改变序列的值$b_i=\sum\limits_{j=i-k*x}a_j$
求改变完了的序列的$(i\times a[i])$值的异或和
思路
通过打表观察可以发现,一种操作多次操作就是把序列$a$和组合数序列进行卷积,然后就直接用ntt就行了
AC代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 5e5 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
typedef pair<int, int> pis;
#define g 3
#define Mod(x) ((x)>=mod?(x)-mod:(x))
ll rnk[maxn];
ll a[maxn], b[maxn];
ll Ksm(ll a, ll b) {
ll res = 1;
while(b) {
if(b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
ll Fac[1000005], inv[1000005];
void FacPre() {
inv[0] = Fac[0] = 1;
for (int i = 1; i <= 1000000; i ++)
Fac[i] = 1ll * Fac[i-1] * i % mod;
inv[1000000] = Ksm(Fac[1000000], mod-2);
for (int i = 999999; i >= 1; i --)
inv[i] = 1ll * inv[i+1] * (i+1) % mod;
}
ll C(int n, int m) {
if(m > n) return 0;
return 1ll * Fac[n] * inv[m] % mod * inv[n-m] % mod;
}
void ntt(long long *a, int op, int n) {
for (int i = 0; i < n; i ++)
if(i < rnk[i]) swap(a[i], a[rnk[i]]);
for (int i = 2; i <= n; i <<= 1) {
int nw = Ksm(g, (mod-1)/i);
if(op == -1) nw = Ksm(nw, mod-2);
for (int j = 0, m = i >> 1; j < n; j += i)
for (int k = 0, w = 1; k < m; k ++) {
int t = 1ll * a[j+k+m] * w % mod;
a[j+k+m] = Mod(a[j+k]-t+mod);
a[j+k] = Mod(a[j+k]+t);
w = 1ll * w * nw % mod;
}
}
if(op == -1)
for (int i = 0, inv = Ksm(n, mod-2); i < n; i ++)
a[i] = 1ll * a[i] * inv % mod;
}
void solve(ll *a, ll *b, int len) {
int n = 1, lim = 0;
while(n <= len + len) n <<= 1, lim++;
for (int i = 0; i < n; i ++)
rnk[i] = (rnk[i>>1]>>1) | ((i&1) << (lim-1));
ntt(a, 1, n); ntt(b, 1, n);
for (int i = 0; i < n; i ++)
a[i] = (1ll * a[i] * b[i]) % mod;
ntt(a, -1, n);
for (int i = len; i < n; i ++) a[i] = 0;
}
int cnt[5];
int main() {
FacPre();
int t;
scanf("%d", &t);
while(t --) {
memset(cnt, 0, sizeof(cnt));
int n, m;
scanf("%d %d", &n, &m);
for (int i = 0; i < n; i ++)
scanf("%lld", &a[i]);
for (int i = 1, op; i <= m; i ++) {
scanf("%d", &op);
cnt[op] ++;
}
for (int i = 1; i <= 3; i ++) {
memset(b, 0, sizeof(b));
for (int j = 0; j * i < n; j ++)
b[j*i] = C(cnt[i]-1+j, j);
if(cnt[i] == 0) b[0] = 1;
solve(a, b, n);
}
ll ans = 0;
for (int i = 0; i < n; i ++) ans = ans ^ (1ll * (i+1) * a[i]);
printf("%lld\n", ans);
}
return 0;
}