Cydiater

「HackerRank」Cards Permutation

给出一个长度为$n$的不完整的排列,其中有一些元素是未知的。把所有的排列按照字典序排序,对于排名为$i$的排列编号为$i$,如果一个排列可以删去若干个元素变成这个不完整的排列,那么我们称这个排列是合法的。你需要输出所有合法排列的编号之和。

方便起见,我们把所有的元素减去1,然后我们考虑怎么解决这个问题。很明显,我们从暴力优化的角度是不好做的,因为排列的数量太多了。我们考虑,对于一个排列来说,如果存在一个排列字典序比他小的话,那么一定存在一个$i$,满足对于$j<i$的所有$j$,$a[j] = b[j]$,而且$a[i]<b[i]$。所以我们可以尝试从那个转折点来分析。

我们枚举那个转折点,对于前$i-1$的部分,肯定满足他们都是相等的,而$i$这个位置所选的数,必须比已经存在的数小,而且不能在前面出现过,我们分类讨论一下。

  1. $a[i]\not = -1$
    假设总共有$k$个位置满足上面的数为$-1$,那么所有的方案为$k!\times a[i]$,接着考虑删掉不合法的方案,首先已经在前面出现过的一定是不合法的,即我们可以随便拿一个数据结构求一下有多少个数小于$a[i]$,不妨设为$x$,那么这一部分应该排除掉的答案就是$k!\times x$,接着考虑前面为$-1$的情况,首先在前面为$-1$的地方确定一个位置,设为$c$,然后在所有可以选的数里选择一个数,设为$d$,即$(k - 1)!\times c\times d$,把所有不合法的答案减掉后再乘上$(N - i)!$就好了。很明显这覆盖掉了所有的情况。

  2. $a[i] = -1$
    这个情况和上面的其实是比较类似的。不再说多了。

最后因为我们没有考虑到相等的情况,所以把答案再加上$k!$

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cmax(a,b)        a = ((a) > (b) ? (a) : (b))
#define cmin(a,b)        a = ((a) < (b) ? (a) : (b))
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pii            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

const int MAXN = 3e5 + 5;
const int oo = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const int inv2 = (mod + 1) >> 1;

inline int add(int a, int b){a += b; return a >= mod ? a - mod : a;}
inline int pop(int a, int b){a -= b; return a < 0 ? a + mod : a;}
inline int mul(int a, int b){return 1LL * a * b % mod;}

int N, a[MAXN], fac[MAXN], pre[MAXN], cnt = 0, ans = 0, cur = 0, lex = 0;

namespace BIT{
    int c[MAXN];
    inline int lowbit(int i){return i & (-i);}
    inline void upd(int o, int v){
        for (int i = o + 1; i <= N; i += lowbit(i)) c[i] += v;
    }
    inline int calc(int o){
        int sum = 0;
        for (int i = o + 1; i >= 1; i -= lowbit(i)) sum += c[i];
        return sum;
    }
}using namespace BIT;

namespace solution{
    int cal2(int n){return mul(mul(n, pop(n, 1)), inv2);}
    void Prepare(){
        scanf("%d", &N);
        up (i, 1, N) scanf("%d", &a[i]);
        up (i, 1, N) {
            a[i]--;
            cnt += (a[i] == -1);
            if (a[i] >= 0) pre[a[i]] = 1;
        }
        fac[0] = 1;
        up (i, 1, N) fac[i] = mul(i, fac[i - 1]);
        up (i, 1, N - 1) pre[i] += pre[i - 1];
        lex = mul(mul(N, pop(N, 1)), inv2);
        up (i, 1, N) if (a[i] != -1) cpop(lex, a[i]);
    }
    void Solve(){
        up (i, 1, N) {
            if (a[i] != -1) {
                int sum = mul(fac[cnt] , a[i] - calc(a[i]));
                if (cnt >= 1) cpop(sum, mul(fac[cnt - 1], mul(cur, a[i] + 1 - pre[a[i]])));
                cmul(sum, fac[N - i]);
                cadd(ans, sum);
                upd(a[i], 1);
                cpop(lex, pop(N - 1 - a[i], pop(pre[N - 1], pre[a[i]])));
            }else {
                int sum = mul(lex, fac[cnt - 1]);
                if (cnt >= 2) cpop(sum, mul(fac[cnt - 2], mul(cur, cal2(cnt))));
                cmul(sum, fac[N - i]);
                cadd(ans, sum);
                cur++;
            }
        }
        printf("%d\n", add(ans, fac[cnt]));
    }
}

int main(){
    using namespace solution;
    Prepare();
    Solve();
    return 0;
}