Cydiater

「Codeforces 896D」Nephren Runs a Cinema

给定一个长度为$n$的序列,每个位置上的值为$[-1, 1]$,要求有多少种方案,满足任意前缀和均非负,且序列的元素和在$[l,r]$,答案对一个非质数取余后输出。

首先我们先考虑,每个位置上的值是${ -1, 1 }$时的答案。然后我们强制把$l,r$的奇偶性置为和$n$一样的,因为可以发现只有奇偶性相同的取值才存在合法的构造方案。然后,因为有前缀均非负的要求,而类似的构造是卡特兰数的方案,而卡特兰数的方案是$C_{2n}^{n} - C_{2n}^{n - 1}$,那么能不能扩展到这个问题呢。

考虑从卡特兰数的构造的证明出发,考虑满足序列长度为$n$,首项是$1$,末项是$m$的方案数。首先考虑所有可能的方案即:$C_n^{\frac{n + m}{2}}$,考虑剔除掉所有不合法的方案即存在前缀和为负的情况。我们考虑把一个$1$变成$-1$后对应的所有方案即 $C_n^{\frac{n + m}{2} + 1}$,这个方案即对应了所有不合法的情况,我们取第一个满足前缀为负的前缀,然后把这段前缀后面的一段取相反数即可,然后我们把$m \in [l, r]$的所有方案累加起来,就是

$$C_{n}^{\frac{n + l}{2}} - C_{n}^{\frac{n + r}{2} + 1}$$

当然,证明也可以从网格路径方案数来考虑,这里不再赘述。

这样,所有的方案我们就可以在$O(\log)$的时间内处理出来,然后我们考虑在中间填上$0$就行了。

然后,因为模数不是大质数,我们并不能方便的预处理出组合数。这个就很模板了,Lucas + CRT 一下就行了。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline void cmax(T & x, T y){y > x ? x = y : 0;}
template<typename T> inline void cmin(T & x, T y){y < x ? x = y : 0;}

int mod, ss;

inline int add(ll a, ll b){a += b; return a >= mod ? a - mod : a;}
inline int pop(ll a, ll b){a -= b; return a < 0 ? a + mod : a;}
inline int mul(ll a, ll b){return a * b % mod;}

int qpow(int a, int b){
    int c = 1;
    while (b) {
        if (b & 1) cmul(c, a);
        cmul(a, a); b >>= 1;
    }
    return c;
}

const int MAXN = 1e5 + 5;
const int oo = 0x3f3f3f3f;

int fac[MAXN], inv[MAXN], cnt[MAXN], c, d, A, M;

int gcd(int a, int b){return !b ? a : gcd(b, a % b);}

void exgcd(int a, int b, int &x, int &y){
    if (!b) {x = 1; y = 0; return;}
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

int Inv(int a){
    int x, y;
    exgcd(a, mod, x, y);
    return (x % mod + mod) % mod;
}

int C(int a, int b){
    if (a < 0 || b < 0 || a < b)     return 0;
    if (mod == ss) {
        if (a < mod) return mul(fac[a], mul(inv[b], inv[a - b]));
        return mul(C(a / mod, b / mod), C(a % mod, b % mod));
    }
    int res = 1, cnt = 0;
    for (int i = a; i; i /= ss) {
        cnt += i / ss;
        cmul(res, fac[i]);
    }
    for (int i = b; i; i /= ss) {
        cnt -= i / ss;
        cmul(res, Inv(fac[i]));
    }
    for (int i = a - b; i; i /= ss) {
        cnt -= i / ss;
        cmul(res, Inv(fac[i]));
    }
    cmul(res, qpow(ss, cnt));
    return res;
}

int calc(int n, int l, int r){
    if (ss == mod) {
        fac[0] = inv[0] = inv[1] = 1;
        up (i, 1, n) fac[i] = mul(i, fac[i - 1]);
        up (i, 2, n) inv[i] = mul(mod - mod / i, inv[mod % i]);
        up (i, 1, n) cmul(inv[i], inv[i - 1]);
    }else {
        fac[0] = 1;
        up (i, 1, n) {
            if (i % ss == 0) fac[i] = fac[i - 1];
            else fac[i] = mul(i % mod, fac[i - 1]);
        }
    }
    int sum = 0;
    up (i, 0, n) {
        int dl = l, dr = r, m = n - i;
        cmin(dr, m);
        if (m & 1) {
            if (!(dl & 1)) dl++;
            if (!(dr & 1)) dr--;
        }else {
            if (dl & 1) dl++;
            if (dr & 1) dr--;
        }
        if (dl > dr) {
            cadd(sum, 0);
            continue;
        }
        int cl = (m + dl) / 2, cr = (m + dr) / 2 + 1;
        cadd(sum, mul(C(n, i), pop(C(m, cl), C(m, cr))));
    }
    return sum;
}

void CRT(int a, int m){
    int x, y, c = a - A, d = gcd(-m, M);
    exgcd(M, -m, x, y);
    int t = -m / d;
    x = ((1LL * x * (c / d) % t) + t) % t;
    int tmp = M * m;
    A = ((A + 1LL * x * M) % tmp + tmp) % tmp;
    M = tmp;
}

int main(){
    int n, p, l, r;
    scanf("%d%d%d%d", &n, &p, &l, &r);
    int sq = sqrt(p);
    A = 0; M = 1;
    up (i, 2, sq) if (p % i == 0) {
        mod = 1; ss = i;
        while (p % i == 0) {
            mod *= i;
            p /= i;
        }
        int a = calc(n, l, r);
        CRT(a, mod);
    }
    if (p > 1) {
        ss = mod = p;
        int a = calc(n, l, r);
        CRT(a, mod);
    }
    printf("%d\n", A);
    return 0;
}