bzoj2142 礼物

题意:小E从商店中购买了n件礼物,打算送给m个人,其中送给第i个人礼物数量为wi。请你帮忙计算出送礼物的方案数(两个方案被认为是不同的,当且仅当存在某个人在这两种方案中收到的礼物不同)。由于方案数可能会很大,你只需要输出模P后的结果。

明显答案是$C_n^{w_1}\times C_{n-w_1}^{w_2} \times \cdots$,如果要求的礼物总数超过了买的礼物,则直接输出Impossible

求解一个组合数$C_n^m$对$p_i^{k_i}$的过程可以分别求分子和分母,然后再进行运算,即,求出$n!\mod p_i^{k_i}$,$m!\mod p_i^{k_i}$,$(n-m)!\mod p_i^{k_i}$后合并。注意分子可以含有大于$k_i$个$p_i$,这样的话求出来就直接为0了,而除掉分母的时候是可能会除掉一些$p_i$的;同时分母本身也可能有含有$p_i$的约数,于是统一把$p_i$移除来,不乘进答案,接下来用分子有的$p_i$的质数减掉分母的,然后乘上去即可。不可能出现减掉后为负数的情况,因为那样意味着组合数除不尽;有可能出现减掉后仍然大于$k_i$的情况,那么这个组合数在模$p_i^{k_i}$下为$0$。

具体地,对于求$n! \mod p_i^{k_i}$可以分为几个部分求解。将每个$p_i$的倍数提出一个$p_i$来,组成$p_i^{\lfloor \frac{n}{p_i} \rfloor}$;对于剩余的部分,容易发现除掉了$p_i$的数(原来是$p_i$的倍数)组成了一个新的阶乘,即$\lfloor \frac{n}{p_i} \rfloor !$,这个我们可以递归求解。剩下的数形如$1,2,\cdots,p-1,p+1,p+2,\cdots,2p-1 \cdots$(因为$p_i$的倍数已经通过前两步瓜分到阶乘和次幂里去了),由于$a \times b \mod p=(a \mod p) \times (b \mod p)$,易得$\prod _{i=1}^{p_i^{k_i}-1}i[i \mod p \neq 0]$与$\prod _{i=p_i^{k_i}+1}^{2p_i^{k_i}-1}i[i \mod p \neq 0] $在$\mod p_i^{k_i}$下是相等的,那么只需求一个$p^k$的长度的乘积在模$p_i^{k_i}$下的结果即可,而这个阶乘里包含了$ \lfloor \frac{n}{p_i^{k_i}} \rfloor$个这个余数,快速幂即可;还有剩下来的、没有达到$p_i^{k_i}$的一段数,由于长度不会超过$p_i^{k_i}$(即$10^5$),直接计算即可。

考虑求解答案,因为模的是一个和数,所以不能直接用普通的$lucas$求解。将和数$P$分解,得到$P=\prod_{i=1}^{k}p_i^{k_i}$,我们要求的答案模$p_i^{k_i}$的结果,对于单独求组合数在模每个$p_i^{k_i}$的剩余系下的结果要相同。于是考虑对于每个组合数,单独求出对$p_i^{k_i}$的结果,然后通过中国剩余定理求解即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=6;
const int maxp=(int)1e5+5;
int P,n,m,w[maxn],pri[35],cnt[35],tot=0,c[35],R[35],M[35];
map<int ,int > re;
inline int maybs(int x) {return x<0?-x:x;}
void div(int x)
{
for(int i=2;i*i<=x;i++) {
if(x%i==0) {
pri[++tot]=i;re[i]=tot;
while(x%i==0) x/=i,cnt[tot]++;
}
}
if(x!=1) {
if(re.find(x)==re.end()) pri[++tot]=x,re[x]=tot;
cnt[re[x]]++;
}
}
int quickpow(int x,int k)
{
int ret=1;
while(k>0) {
if(k&1) ret=ret*x;
x=x*x; k>>=1;
}
return ret;
}
int quickpow(int x,int k,int mo)
{
int ret=1;
while(k>0) {
if(k&1) ret=1ll*ret*x%mo;
x=1ll*x*x%mo; k>>=1;
}
return ret;
}
int gcd(int a,int b) {return b==0?a:gcd(b,a%b);}
int extgcd(int a,int b,int &x,int &y)
{
if(b==0) {x=1,y=0; return a;}
else {int d=extgcd(b,a%b,y,x);y-=a/b*x;return d;}
}
int inv(int q,int p)
{
assert(gcd(p,q)==1);
int x,y;
extgcd(q,p,x,y);
x=(x%p+p)%p; if(!x) x+=p;
return x;
}
int Cnt(int x,int p) {
if(x==0) return 0;
return Cnt(x/p,p)+x/p;
}
typedef pair<int ,int > pii;
#define fir first
#define sec second
#define MP make_pair
pii Cal(int x,int p,int k)
{
if(x==1 || x==0) return MP(1,0);
int ret=1,del=0,tmp=quickpow(p,k);
int cou=Cnt(x,p);
del=cou;
pii nw=Cal(x/p,p,k);
ret=1ll*ret*nw.fir%tmp;
if(x>=tmp) {
int lev=1;
for(int i=1;i<tmp;i++) {
if(i%p==0) continue;
lev=1ll*lev*i%tmp;
}
ret=1ll*ret*quickpow(lev,x/tmp,tmp)%tmp;
}
for(int i=x;i>=1;i--) {
if(i%tmp==0) break;
if(i%p==0) continue;
ret=1ll*ret*i%tmp;
}
return MP(ret,del);
}
int getAns()
{
for(int i=1;i<=tot;i++) M[i]=1;
for(int i=1;i<=tot;i++) {
for(int j=1;j<=tot;j++)
if(i==j) continue;
else M[i]=1ll*M[i]*R[j]%P;
}
int ret=0;
for(int i=1;i<=tot;i++) ret=(ret+((1ll*c[i]*M[i]%P)*1ll*inv(M[i],R[i]))%P)%P;
return ret;
}
int main()
{
scanf("%d",&P);
div(P);
scanf("%d%d",&n,&m);
int sum=0;
for(int i=1;i<=m;i++) {scanf("%d",&w[i]);sum+=w[i];}
if(sum>n) {printf("Impossible\n");exit(0);}
int ans=1,N=n,M,res=1;
pii up,dw1,dw2;
for(int i=1;i<=m;i++) {
M=w[i];
for(int j=1;j<=tot;j++) R[j]=quickpow(pri[j],cnt[j]);
for(int j=1;j<=tot;j++) {
ans=1;
up=Cal(N,pri[j],cnt[j]);
dw1=Cal(M,pri[j],cnt[j]),dw2=Cal(N-M,pri[j],cnt[j]);
up.sec-=dw1.sec; up.sec-=dw2.sec;
assert(up.sec>=0);
if(up.sec>=cnt[j]) ans=0;
else {
ans=1ll*ans*up.fir%R[j];
ans=1ll*ans*inv(dw1.fir,R[j])%R[j]; ans=1ll*ans*inv(dw2.fir,R[j])%R[j];
ans=1ll*ans*quickpow(pri[j],up.sec,R[j])%R[j];
}
c[j]=ans;
}
res=1ll*res*getAns()%P;
N-=w[i];
}
printf("%d\n",res);
return 0;
}