Loj#2541.「PKUWC 2018」猎人杀
题意
有 $n$ 个人,每轮杀死一个。
每个人有正整数权值 $w_i$,一轮里某个人被杀的概率为他的权值除以当前剩余的人的权值和。
$\sum\limits_{i=1}^{n}w_i \le 100000$
题解
将每个人变为 $w_i$ 个点,所有点随机排列。
从前到后选,选到一个点就把同类型全部删去。
合法的方案数就是第一个人的第一个点前每种点都至少有一个。
显然可以 $dp[i][j]$ 表示第一个人的第一个点前有 $i$ 个点,后有 $j$ 个点。
枚举往前加 $x$ 个,往后加 $y$ 个,转移就是 $dp[i+x][j+y]+=dp[i][j] \times \dbinom{i+x}{x} \times \dbinom{j+y}{y}$。
这显然是一个卷积形式,相当于 $n$ 个多项式的乘积。
所以分治 NTT 即可。
时间复杂度 $O(n \log^2n)$。
代码
#include <bits/stdc++.h>
#define gc getchar()
#define root 1,1,n
#define lc cur<<1
#define rc lc|1
#define lson lc,l,mid
#define rson rc,mid+1,r
using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=400009;
int n,a[N],sg[N],*beg[N],len[N],jc[N],jc_inv[N],inv[N],cnt;
int read()
{
int x=1;
char ch;
while (ch=gc,ch<'0'||ch>'9') if (ch=='-') x=-1;
int s=ch-'0';
while (ch=gc,ch>='0'&&ch<='9') s=s*10+ch-'0';
return s*x;
}
int ksm(int x,int y,int ret=1)
{
for (;y;y>>=1,x=(ll)x*x%mod)
if (y&1) ret=(ll)ret*x%mod;
return ret;
}
int C(int n,int m)
{
if (n<m) return 0;
return (ll)jc[n]*jc_inv[m]%mod*jc_inv[n-m]%mod;
}
int lim,w[N],rev[N];
void init(int n)
{
int k=0;
lim=1;
while (lim<=n) k++,lim<<=1;
int G=ksm(3,(mod-1)/lim);
w[0]=w[lim]=1;
for (int i=1;i<lim;i++) w[i]=(ll)w[i-1]*G%mod;
for (int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
void dft(int *a)
{
for (int i=0;i<lim;i++)
if (rev[i]<i) swap(a[i],a[rev[i]]);
for (int i=2;i<=lim;i<<=1)
for (int j=0;j<lim;j+=i)
for (int k=0;k<i>>1;k++)
{
int x=a[j+k],y=(ll)a[j+k+(i>>1)]*w[lim/i*k]%mod;
a[j+k+(i>>1)]=(x-y<0?x-y+mod:x-y);
a[j+k]=(x+y>=mod?x+y-mod:x+y);
}
}
int A[N],B[N];
void ntt(int *a,int n,int *b,int m)
{
init(n+m);
memset(A,0,sizeof(int)*lim),memset(B,0,sizeof(int)*lim);
for (int i=0;i<=n;i++) A[i]=a[i];
for (int i=0;i<=m;i++) B[i]=b[i];
dft(A),dft(B);
for (int i=0;i<lim;i++) A[i]=(ll)A[i]*B[i]%mod;
dft(A),reverse(A+1,A+lim);
for (int i=0,t=ksm(lim,mod-2);i<=n+m;i++) a[i]=(ll)A[i]*t%mod;
}
void solve(int cur,int l,int r)
{
if (l==r)
{
beg[cur]=sg+cnt;
if (l==1)
{
sg[cnt++]=(ll)jc_inv[a[1]-1];
len[cur]=0;
}
else
{
sg[cnt++]=0;
for (int i=1;i<=a[l];i++)
sg[cnt++]=(ll)jc_inv[i]*jc_inv[a[l]-i]%mod;
len[cur]=a[l];
}
return;
}
int mid=(l+r>>1);
solve(lson),solve(rson);
ntt(beg[lc],len[lc],beg[rc],len[rc]);
beg[cur]=beg[lc],len[cur]=len[lc]+len[rc];
}
int sum=0;
int main()
{
jc[0]=1;
for (int i=1;i<N;i++) jc[i]=(ll)jc[i-1]*i%mod;
inv[1]=1;
for (int i=2;i<N;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
jc_inv[0]=1;
for (int i=1;i<N;i++) jc_inv[i]=(ll)jc_inv[i-1]*inv[i]%mod;
n=read();
for (int i=1;i<=n;i++) a[i]=read(),sum+=a[i];
solve(root);
int ret=0;
for (int i=0;i<=len[1];i++)
ret=(ret+(ll)sg[i]*jc[i]%mod*jc[sum-1-i]%mod)%mod;
for (int i=1,S=0;i<=n;i++)
ret=(ll)ret*ksm(C(S+a[i],a[i]),mod-2)%mod,S+=a[i];
printf("%d\n",ret);
return 0;
}
No Comments