牛客多校第八场


J:Just Jump

题意

给你一个长度为L的路径,你起始点在0号点,每次最少移动d步,而且在$t_i$时刻$p_i$点会遭遇攻击

问你到达L点有多少种方法

思路

首先不考虑攻击,到达L点的方法就是一个简单dp $dp[n]=dp[0]+dp[1]+dp[2]+..+dp[n-d]$

可以用前缀和维护一下,这样dp的复杂度就是O(L)

那么考虑攻击的时候呢,一个明显的思路就是我们已经用dp算出了总的方法数,

那么我们把受到攻击是的路线减去就是不受攻击到达L的方法数

那么当我们在$(t_i,p_i)$受到攻击的方法数是多少呢,我们先把攻击按照位置从小到大排序,时间从小到大排序

首先我们肯定是从$[0,p_i-d]$这些点数转移来的,并且前面的$t_i-1$步每次最少走d步

那么我们在减去$(t_i-1)(d-1)$就变成从最少有1步,那么就是从$[1,p_i-d-(t_i-1)(d-1)]$找$t_i-1$个点,也就是组合数$C_{p_i-t_id+t_i-1}^{t_i-1}$

但是现在又有了一个问题,就是用$C_{p_i-t_id+t_i-1}^{t_i-1}$ 来计算

img1

我们计算了0—>D的所有路径,但是其中包括了从C到D的路径,而这部分路径因为C的被攻击所以是无效的(C被攻击的时间小于D,C在D之前被攻击),所以我们要减去(C—>D)这部分路径

这样我们就把从0开始$t_i$时刻到达攻击点$p_i$ 的路径数$f[i]$全部算了出来,那么结果就是$dp[L]-=\sum\limits_{i=1}^{m}f[i]*dp[L-node[i].p]$

AC代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 1e7;
const int mod = 998244353;
ll Fac[maxn+5], inv[maxn+5], f[3005];
ll pref[maxn+5], sum[maxn+5];
ll L, d, m;

struct Node{
    ll t, p;
    bool friend operator < (Node a, Node b) {
        if(a.p == b.p) return a.t < b.t;
        return a.p < b.p;
    }
}node[3005];

ll Ksm(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

void init() {
    Fac[0] = 1;
    for (int i = 1; i <= maxn; i ++)
        Fac[i] = (Fac[i-1]*i) % mod;
    inv[maxn] = Ksm(Fac[maxn], mod-2);
    for (int i = maxn-1; i >= 0; i --)
        inv[i] = inv[i+1] * (i+1) % mod;
}

ll Calc(ll t, ll p) {
    if(p-t*d+t-1 < t-1) return 0;
    return Fac[p-t*d+t-1]*inv[t-1]%mod*inv[p-t*d]%mod;
}

int main() {
    init();
    scanf("%lld %lld %lld", &L, &d, &m); 
    for (int i = 1; i <= m; i ++) 
        scanf("%lld %lld", &node[i].t, &node[i].p);
    sort(node+1, node+1+m);
    for (int i = 1; i <= m; i ++) {
        f[i] = Calc(node[i].t, node[i].p);
        for (int j = 1; j < i; j ++) 
            if(node[i].t > node[j].t) 
                f[i] = (f[i] - f[j]*Calc(node[i].t-node[j].t, node[i].p-node[j].p)%mod + mod) % mod;
    }
    sum[0] = 1; pref[0] = 1;
    for (int i = 1; i <= L; i ++) {
        if(i < d) sum[i] = 0;
        else sum[i] = pref[i-d];
        pref[i] = (pref[i-1]+sum[i]) % mod;
    }
    ll ans = sum[L];
    for (int i = 1; i <= m; i ++) 
        ans = (ans - f[i] * sum[L-node[i].p] % mod + mod) % mod;
    printf("%lld\n", ans);
    return 0;
}

文章作者: Mug-9
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Mug-9 !
评论
  目录