Loj#2541.「PKUWC 2018」猎人杀

Loj#2541.「PKUWC 2018」猎人杀

题意

有 $n$ 个人,每轮杀死一个。
每个人有正整数权值 $w_i$,一轮里某个人被杀的概率为他的权值除以当前剩余的人的权值和。
$\sum\limits_{i=1}^{n}w_i \le 100000$

题解

将每个人变为 $w_i$ 个点,所有点随机排列。
从前到后选,选到一个点就把同类型全部删去。
合法的方案数就是第一个人的第一个点前每种点都至少有一个。
显然可以 $dp[i][j]$ 表示第一个人的第一个点前有 $i$ 个点,后有 $j$ 个点。
枚举往前加 $x$ 个,往后加 $y$ 个,转移就是 $dp[i+x][j+y]+=dp[i][j] \times \dbinom{i+x}{x} \times \dbinom{j+y}{y}$。
这显然是一个卷积形式,相当于 $n$ 个多项式的乘积。
所以分治 NTT 即可。
时间复杂度 $O(n \log^2n)$。

代码

#include <bits/stdc++.h>
#define gc getchar()
#define root 1,1,n
#define lc cur<<1
#define rc lc|1
#define lson lc,l,mid
#define rson rc,mid+1,r
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=400009;
int n,a[N],sg[N],*beg[N],len[N],jc[N],jc_inv[N],inv[N],cnt;
int read()
{
    int x=1;
    char ch;
    while (ch=gc,ch<'0'||ch>'9') if (ch=='-') x=-1;
    int s=ch-'0';
    while (ch=gc,ch>='0'&&ch<='9') s=s*10+ch-'0';
    return s*x;
}
int ksm(int x,int y,int ret=1)
{
    for (;y;y>>=1,x=(ll)x*x%mod)
        if (y&1) ret=(ll)ret*x%mod;
    return ret;
}
int C(int n,int m)
{
    if (n<m) return 0;
    return (ll)jc[n]*jc_inv[m]%mod*jc_inv[n-m]%mod;
}
int lim,w[N],rev[N];
void init(int n)
{
    int k=0;
    lim=1;
    while (lim<=n) k++,lim<<=1;
    int G=ksm(3,(mod-1)/lim);
    w[0]=w[lim]=1;
    for (int i=1;i<lim;i++) w[i]=(ll)w[i-1]*G%mod;
    for (int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
void dft(int *a)
{
    for (int i=0;i<lim;i++)
        if (rev[i]<i) swap(a[i],a[rev[i]]);
    for (int i=2;i<=lim;i<<=1)
        for (int j=0;j<lim;j+=i)
            for (int k=0;k<i>>1;k++)
            {
                int x=a[j+k],y=(ll)a[j+k+(i>>1)]*w[lim/i*k]%mod;
                a[j+k+(i>>1)]=(x-y<0?x-y+mod:x-y);
                a[j+k]=(x+y>=mod?x+y-mod:x+y);
            }
}
int A[N],B[N];
void ntt(int *a,int n,int *b,int m)
{
    init(n+m);
    memset(A,0,sizeof(int)*lim),memset(B,0,sizeof(int)*lim);
    for (int i=0;i<=n;i++) A[i]=a[i];
    for (int i=0;i<=m;i++) B[i]=b[i];
    dft(A),dft(B);
    for (int i=0;i<lim;i++) A[i]=(ll)A[i]*B[i]%mod;
    dft(A),reverse(A+1,A+lim);
    for (int i=0,t=ksm(lim,mod-2);i<=n+m;i++) a[i]=(ll)A[i]*t%mod;
}
void solve(int cur,int l,int r)
{
    if (l==r)
    {
        beg[cur]=sg+cnt;
        if (l==1)
        {
            sg[cnt++]=(ll)jc_inv[a[1]-1];
            len[cur]=0;
        }
        else
        {
            sg[cnt++]=0;
            for (int i=1;i<=a[l];i++)
                sg[cnt++]=(ll)jc_inv[i]*jc_inv[a[l]-i]%mod;
            len[cur]=a[l];
        }
        return;
    }
    int mid=(l+r>>1);
    solve(lson),solve(rson);
    ntt(beg[lc],len[lc],beg[rc],len[rc]);
    beg[cur]=beg[lc],len[cur]=len[lc]+len[rc];
}
int sum=0;
int main()
{
    jc[0]=1;
    for (int i=1;i<N;i++) jc[i]=(ll)jc[i-1]*i%mod;
    inv[1]=1;
    for (int i=2;i<N;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
    jc_inv[0]=1;
    for (int i=1;i<N;i++) jc_inv[i]=(ll)jc_inv[i-1]*inv[i]%mod;
    n=read();
    for (int i=1;i<=n;i++) a[i]=read(),sum+=a[i];
    solve(root);
    int ret=0;
    for (int i=0;i<=len[1];i++)
        ret=(ret+(ll)sg[i]*jc[i]%mod*jc[sum-1-i]%mod)%mod;
    for (int i=1,S=0;i<=n;i++)
        ret=(ll)ret*ksm(C(S+a[i],a[i]),mod-2)%mod,S+=a[i];
    printf("%d\n",ret);
    return 0;
}

 

点赞 2

No Comments

Add your comment