Cydiater

arc087F Squirrel Migration

给定一个长度为$n$的排列和大小为$n$的树,定义一个排列的价值为$\sum dis(i, p_i)$,求有多少个排列价值是最大的。

可以发现,这道题如果直接按距离是不好做的,转化为每条边可以经过多少次。我们发现,如果一条边的左右端点的子树大小是$size_u,size_v$,那么一条边最多可以被经过的次数是$2\times \min (size_u, size_v)$,我们考虑构造一种方案,使得每条边被经过的次数都达到最大值。

首先我们找到这棵树的重心。先考虑有两个重心的情况。这个时候两边的子树大小是相等的,我们对于两边的子树每个点,选择另一个子树的点即可。可以发现这样可以使得所有边都得到充分的利用。

接着考虑只有一个重心的情况。我们发现,对于每个点来说,并不好表示他选择了外部的哪个点,我们考虑计算每个点选择内部的点,然后容斥掉即可,可以发现,对于一个子树内选择$k$个起点和$k$个终点并且一一映射的方案是很好计算的。那么我们对于每个子树的方案依次处理,最后合并一下背包就行了。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<int, int>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline bool cmax(T & x, T y){return y > x ? x = y, true : false;}
template<typename T> inline bool cmin(T & x, T y){return y < x ? x = y, true : false;}

const int MAXN = 5005;;
const int oo = 0x3f3f3f3f;
const int mod = 1e9 + 7;

inline int add(int a, int b){a += b; return a >= mod ? a - mod : a;}
inline int pop(int a, int b){a -= b; return a < 0 ? a + mod : a;}
inline int mul(int a, int b){return 1LL * a * b % mod;}

int N, fac[MAXN], f[MAXN], g[MAXN], a[MAXN], C[MAXN][MAXN], M = 0, ans = 0;

struct edge {
    int y, next;
}e[MAXN << 1];
int LINK[MAXN], len = 0, siz[MAXN], mxsiz[MAXN], rt[MAXN], rts;
inline void ins(int x, int y){
    e[++len].next = LINK[x]; LINK[x] = len;
    e[len].y = y; 
}
inline void Ins(int x, int y){
    ins(x, y);
    ins(y, x);
}

void DFS(int node, int fa){
    siz[node] = 1; mxsiz[node] = 0;
    Auto (i, node) if (e[i].y != fa) {
        DFS(e[i].y, node);
        siz[node] += siz[e[i].y];
        cmax(mxsiz[node], siz[e[i].y]);
    }
    cmax(mxsiz[node], N - siz[node]);
    if (mxsiz[node] < mxsiz[rt[rts]]) rt[rts = 1] = node;
    else if (mxsiz[node] == mxsiz[rt[rts]]) rt[++rts] = node;
}

void Fix(int node, int fa){
    siz[node] = 1;
    Auto (i, node) if (e[i].y != fa) {
        Fix(e[i].y, node);
        siz[node] += siz[e[i].y];
    }
}

int main(){
    scanf("%d", &N);
    C[0][0] = 1;
    up (i, 1, N) {
        C[i][0] = 1;
        up (j, 1, i) C[i][j] = add(C[i - 1][j - 1], C[i - 1][j]);
    } 
    fac[0] = 1;
    up (i, 1, N) fac[i] = mul(i, fac[i - 1]);
    up (i, 2, N) {
        int x, y;
        scanf("%d%d", &x, &y);
        Ins(x, y);
    }
    mxsiz[0] = oo;
    DFS(1, 0);
    if (rts == 2) {
        Fix(rt[1], rt[2]);
        Fix(rt[2], rt[1]);
        printf("%d\n", mul(fac[siz[rt[1]]], fac[siz[rt[2]]]));
        return 0;
    }
    Fix(rt[1], 0); 
    Auto (i, rt[1]) a[++M] = siz[e[i].y];
    f[0] = 1;
    int cur = 0;
    up (i, 1, M) {
        memset(g, 0, sizeof(int) * (cur + 1));
        up (j, 0, cur) up (k, 0, a[i]) cadd(g[j + k], mul(f[j], mul(C[a[i]][k], mul(C[a[i]][k], fac[k]))));
        cur += a[i];
        memcpy(f, g, sizeof(int) * (cur + 1));
    } 
    up (i, 0, N) {
        if (i & 1)     cpop(ans, mul(f[i], fac[N - i]));
        else        cadd(ans, mul(f[i], fac[N - i]));
    } 
    printf("%d\n", ans);
    return 0;
}