题目链接:
[UOJ86]mx 的组合数
题目大意: 给出四个数 $p,n,l,r$, 对于 $\forall 0\le a\le p-1$, 求 $l\le x\le r,C_{x}^{n}\%p=a$ 的 $x$ 的数量.$p<=3000$ 且保证 $p$ 是质数,$n,l,r<=10^30$.
对于 $10\%$ 的数据, 可以直接杨辉三角推.
对于 $20\%$ 的数据, 因为 $n$ 是确定的, 可以递推出 $C_{x+1}^{n}=C_{x}^{n}*\frac{x+1}{x+1-n}$.
对于另外 $20\%$ 的数据, 可以枚举 $x$ 然后用 $lucas$ 定理求.
对于另外 $30\%$ 的数据, 可以想到将问题转化成小于等于 $r$ 的个数 $-$ 小于等于 $l-1$ 的个数. 由 $lucas$ 定理可知,$C_{x}^{n}\ mod\ p=\prod C_{b_{i}}^{a_{i}}\ mod\ p$, 其中 $a_{i},b_{i}$ 分别为 $n,x$ 在 $p$ 进制下的第 $i$ 位. 那么我们就可以用数位 $DP$ 求,$f[i][j]$ 代表从最低为开始的前 $i$ 位, 每一位的值都不大于 $b_{i}$ 且 $\%p=j$ 的方案数;$g[i][j]$ 代表从最低为开始的前 $i$ 位, 每一位的值任意且 $\%p=j$ 的方案数. 设枚举第 $i+1$ 位为 $x$,$C_{x}^{a_{i+1}}=k$. 那么可以得到 $DP$ 转移方程 $g[i+1][jk\ mod\ p]+=g[i][j]$, 若 $x<b_{i+1}$, 则 $f[i+1][jk\ mod\ p]+=g[i][j]$, 若 $x=b_{i+1}$, 则 $f[i+1][jk\ mod\ p]+=f[i][j]$. 时间复杂度为 $O(p^2log_{p})$.
对于 $100\%$ 的数据, 我们考虑优化上述 $DP$, 我们拿其中第一个转移方程来说 (后两个同理), 我们设 $h[k]=\sum\limits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$. 可以发现转移可以看成是 $G[j*k\ mod\ p]=\sum\limits_{j=0}^{p-1}g[j]\sum\limits_{k=0}^{p-1}h[k]$, 这和卷积式子很像, 但他是乘法卷积, 我们想办法将它变成加法卷积: 因为 $p$ 是质数, 那么 $p$ 一定有原根 (设为 $g$), 也就是说对于任意 $j$, 其中 $1\le j\le p-1$ 都有指标. 我们设它的指标为 $ind(j)$, 那么 $j*k\ mod\ p$ 就能转化为 $g^{(ind(j)+ind(k))\ mod\ (p-1)}\ mod\ p$. 这样我们就能用 $FFT$ 或 $NTT$ 来加速 $DP$ 了, 但注意到 $0$ 没有指标, 我们在转移时先忽略 $0$, 在最后输出答案时用总个数减掉其他答案就是 $\%p=0$ 的个数了. 注意原根从 $1$ 开始枚举. 至于 $10^{30}$ 可以用 $\_\_int128$ 存. 时间复杂度为 $O(plog_{p}^2)$.
两种写法, 读者自选.
- #include<set>
- #include<map>
- #include<queue>
- #include<stack>
- #include<cmath>
- #include<cstdio>
- #include<vector>
- #include<bitset>
- #include<cstring>
- #include<iostream>
- #include<algorithm>
- #define ll long long
- typedef __int128 int128;
- #define MOD 998244353
- using namespace std;
- int p;
- int128 l,r,n;
- int pr[10];
- int cnt;
- int G;
- int mx;
- ll sum;
- int ind[30010];
- ll f[100000];
- ll g[100000];
- ll h[100000];
- int a[200];
- int b[200];
- ll ans[30010];
- int c[200][30010];
- int mask=1;
- ll s[100000];
- char *p1,*p2,buf[100000];
- #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
- int read_()
- {
- int x=0;
- char c=nc();
- while(c<48)
- {
- c=nc();
- }
- while(c>47)
- {
- x=(((x<<2)+x)<<1)+(c^48),c=nc();
- }
- return x;
- }
- int128 read()
- {
- int128 x=0;
- char c=nc();
- while(c<48)
- {
- c=nc();
- }
- while(c>47)
- {
- x=(((x<<2)+x)<<1)+(c^48),c=nc();
- }
- return x;
- }
- ll quick(int x,int y,int mod)
- {
- ll res=1ll;
- while(y)
- {
- if(y&1)
- {
- res=res*x%mod;
- }
- y>>=1;
- x=1ll*x*x%mod;
- }
- return res;
- }
- void NTT(ll *a,int len,int miku)
- {
- for(int k=0,i=0;i<len;i++)
- {
- if(i>k)
- {
- swap(a[i],a[k]);
- }
- for(int j=len>>1;(k^=j)<j;j>>=1);
- }
- for(int k=2;k<=len;k<<=1)
- {
- int t=k>>1;
- int x=quick(3,(MOD-1)/k,MOD);
- if(miku==-1)
- {
- x=quick(x,MOD-2,MOD);
- }
- for(int i=0;i<len;i+=k)
- {
- ll w=1;
- for(int j=i;j<i+t;j++)
- {
- ll tmp=a[j+t]*w%MOD;
- a[j+t]=(a[j]-tmp+MOD)%MOD;
- a[j]=(a[j]+tmp)%MOD;
- w=w*x%MOD;
- }
- }
- }
- if(miku==-1)
- {
- for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
- {
- a[i]=a[i]*t%MOD;
- }
- }
- }
- void solve(int128 num)
- {
- memset(f,0,sizeof(f));
- memset(g,0,sizeof(g));
- memset(h,0,sizeof(h));
- memset(a,0,sizeof(a));
- int res=0;
- for(int i=1;num;i++)
- {
- a[i]=num%p;
- num/=p;
- res=max(res,i);
- }
- mx=max(res,mx);
- g[0]=f[0]=1ll;
- for(int k=1;k<=mx;k++)
- {
- memset(h,0,sizeof(h));
- memset(s,0,sizeof(s));
- NTT(g,mask,1);
- NTT(f,mask,1);
- if(a[k]>=b[k])
- {
- h[ind[c[k][a[k]]]]++;
- NTT(h,mask,1);
- for(int i=0;i<mask;i++)
- {
- s[i]+=1ll*h[i]*f[i]%MOD;
- s[i]%=MOD;
- }
- NTT(h,mask,-1);
- h[ind[c[k][a[k]]]]--;
- }
- for(int i=b[k];i<a[k];i++)
- {
- h[ind[c[k][i]]]++;
- }
- NTT(h,mask,1);
- for(int i=0;i<mask;i++)
- {
- s[i]+=1ll*h[i]*g[i]%MOD;
- s[i]%=MOD;
- }
- NTT(h,mask,-1);
- NTT(s,mask,-1);
- memset(f,0,sizeof(f));
- for(int i=0;i<mask;i++)
- {
- f[i%(p-1)]+=s[i];
- f[i%(p-1)]%=MOD;
- }
- for(int i=max(b[k],a[k]);i<p;i++)
- {
- h[ind[c[k][i]]]++;
- }
- NTT(h,mask,1);
- for(int i=0;i<mask;i++)
- {
- s[i]=1ll*h[i]*g[i]%MOD;
- }
- NTT(s,mask,-1);
- memset(g,0,sizeof(g));
- for(int i=0;i<mask;i++)
- {
- g[i%(p-1)]+=s[i];
- g[i%(p-1)]%=MOD;
- }
- }
- }
- int main()
- {
- p=read_(),n=read(),l=read(),r=read();
- l--;
- int s=p-1;
- while(mask<(p<<1))
- {
- mask<<=1;
- }
- for(int i=2;i*i<=s;i++)
- {
- if(s%i==0)
- {
- pr[++cnt]=i;
- while(s%i==0)
- {
- s/=i;
- }
- }
- }
- if(s!=1)
- {
- pr[++cnt]=s;
- }
- for(int i=1;i<p;i++)
- {
- bool flag=true;
- for(int j=1;j<=cnt;j++)
- {
- if(quick(i,(p-1)/pr[j],p)==1)
- {
- flag=false;
- break;
- }
- }
- if(flag)
- {
- G=i;
- break;
- }
- }
- sum=1ll;
- for(int i=0;i<p-1;i++)
- {
- ind[sum]=i;
- sum*=G,sum%=p;
- }
- int128 N=n;
- for(int i=1;N;i++)
- {
- b[i]=N%p;
- N/=p;
- mx=max(mx,i);
- }
- for(int i=1;i<=mx;i++)
- {
- for(int j=0;j<b[i];j++)
- {
- c[i][j]=0;
- }
- sum=1ll;
- for(int j=b[i];j<p;j++)
- {
- c[i][j]=sum;
- sum*=(j+1),sum%=p;
- sum*=quick(j+1-b[i],p-2,p),sum%=p;
- }
- }
- solve(l);
- for(int i=0;i<p-1;i++)
- {
- ans[quick(G,i,p)]-=f[i];
- }
- for(int i=1;i<=p-1;i++)
- {
- ans[i]=(ans[i]%MOD+MOD)%MOD;
- }
- solve(r);
- for(int i=0;i<p-1;i++)
- {
- ans[quick(G,i,p)]+=f[i];
- }
- for(int i=1;i<=p-1;i++)
- {
- ans[i]%=MOD;
- }
- ans[0]=(r-l)%MOD;
- for(int i=1;i<p;i++)
- {
- ans[0]-=ans[i];
- ans[0]=(ans[0]%MOD+MOD)%MOD;
- }
- for(int i=0;i<p;i++)
- {
- printf("%lld\n",ans[i]);
- }
- }
- #include<set>
- #include<map>
- #include<queue>
- #include<stack>
- #include<cmath>
- #include<cstdio>
- #include<vector>
- #include<bitset>
- #include<cstring>
- #include<iostream>
- #include<algorithm>
- #define ll long long
- typedef __int128 int128;
- #define MOD 998244353
- using namespace std;
- int p;
- int128 l,r,n;
- int pr[10];
- int cnt;
- int G;
- int mx;
- ll sum;
- int ind[30010];
- ll f[100000];
- ll g[100000];
- ll A[100000];
- ll B[100000];
- ll C[100000];
- int a[200];
- int b[200];
- ll ans[30010];
- int c[200][30010];
- int mask=1;
- int s[100000];
- int pw[300010];
- int fac[300010];
- int inv[300010];
- char *p1,*p2,buf[100000];
- #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
- int read_()
- {
- int x=0;
- char c=nc();
- while(c<48)
- {
- c=nc();
- }
- while(c>47)
- {
- x=(((x<<2)+x)<<1)+(c^48),c=nc();
- }
- return x;
- }
- int128 read()
- {
- int128 x=0;
- char c=nc();
- while(c<48)
- {
- c=nc();
- }
- while(c>47)
- {
- x=(((x<<2)+x)<<1)+(c^48),c=nc();
- }
- return x;
- }
- ll quick(int x,int y,int mod)
- {
- ll res=1ll;
- while(y)
- {
- if(y&1)
- {
- res=res*x%mod;
- }
- y>>=1;
- x=1ll*x*x%mod;
- }
- return res;
- }
- void NTT(ll *a,int len,int miku)
- {
- for(int k=0,i=0;i<len;i++)
- {
- if(i>k)
- {
- swap(a[i],a[k]);
- }
- for(int j=len>>1;(k^=j)<j;j>>=1);
- }
- for(int k=2;k<=len;k<<=1)
- {
- int t=k>>1;
- int x=quick(3,(MOD-1)/k,MOD);
- if(miku==-1)
- {
- x=quick(x,MOD-2,MOD);
- }
- for(int i=0;i<len;i+=k)
- {
- ll w=1;
- for(int j=i;j<i+t;j++)
- {
- ll tmp=a[j+t]*w%MOD;
- a[j+t]=(a[j]-tmp+MOD)%MOD;
- a[j]=(a[j]+tmp)%MOD;
- w=w*x%MOD;
- }
- }
- }
- if(miku==-1)
- {
- for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
- {
- a[i]=a[i]*t%MOD;
- }
- }
- }
- void solve(int128 num)
- {
- memset(f,0,sizeof(f));
- memset(g,0,sizeof(g));
- memset(a,0,sizeof(a));
- int res=0;
- for(int i=1;num;i++)
- {
- a[i]=num%p;
- num/=p;
- res=max(res,i);
- }
- mx=max(res,mx);
- g[1]=f[1]=1ll;
- for(int k=1;k<=mx;k++)
- {
- memset(A,0,sizeof(A));
- memset(B,0,sizeof(B));
- for(int i=b[k];i<p;i++)
- {
- if(c[k][i])
- {
- A[ind[c[k][i]]]++;
- }
- }
- for(int i=1;i<p;i++)
- {
- B[ind[i]]+=g[i];
- B[ind[i]]%=MOD;
- }
- NTT(A,mask,1);
- NTT(B,mask,1);
- for(int i=0;i<mask;i++)
- {
- C[i]=A[i]*B[i]%MOD;
- }
- NTT(C,mask,-1);
- memset(g,0,sizeof(g));
- for(int i=0;i<mask;i++)
- {
- (g[quick(G,i%(p-1),p)]+=C[i])%=MOD;
- }
- memset(A,0,sizeof(A));
- for(int i=b[k];i<a[k];i++)
- {
- if(c[k][i])
- {
- A[ind[c[k][i]]]++;
- }
- }
- NTT(A,mask,1);
- for(int i=0;i<mask;i++)
- {
- C[i]=A[i]*B[i]%MOD;
- }
- NTT(C,mask,-1);
- memset(s,0,sizeof(s));
- for(int i=0;i<mask;i++)
- {
- (s[quick(G,i%(p-1),p)]+=C[i])%=MOD;
- }
- if(c[k][a[k]])
- {
- for(int i=1;i<p;i++)
- {
- (s[c[k][a[k]]*i%p]+=f[i])%=MOD;;
- }
- }
- for(int i=1;i<p;i++)
- {
- f[i]=s[i];
- }
- }
- }
- int get_ori(int p)
- {
- int s=p-1;
- for(int i=2;i*i<=s;i++)
- {
- if(s%i==0)
- {
- pr[++cnt]=i;
- while(s%i==0)
- {
- s/=i;
- }
- }
- }
- if(s!=1)
- {
- pr[++cnt]=s;
- }
- for(int i=1;i<p;i++)
- {
- bool flag=true;
- for(int j=1;j<=cnt;j++)
- {
- if(quick(i,(p-1)/pr[j],p)==1)
- {
- flag=false;
- break;
- }
- }
- if(flag)
- {
- return i;
- break;
- }
- }
- }
- int main()
- {
- p=read_(),n=read(),l=read(),r=read();
- while(mask<(p<<1))
- {
- mask<<=1;
- }
- G=get_ori(p);
- pw[0]=1ll;
- for(int i=1;i<p;i++)
- {
- pw[i]=pw[i-1]*G%p;
- }
- sum=1ll;
- for(int i=0;i<p-1;i++)
- {
- ind[sum]=i;
- sum*=G,sum%=p;
- }
- int128 N=n;
- for(int i=1;N;i++)
- {
- b[i]=N%p;
- N/=p;
- mx=max(mx,i);
- }
- fac[0]=inv[0]=1ll;
- for(int i=1;i<p;i++)
- {
- fac[i]=fac[i-1]*i%p;
- }
- inv[p-1]=quick(fac[p-1],p-2,p);
- for(int i=p-2;i>=1;i--)
- {
- inv[i]=inv[i+1]*(i+1)%p;
- }
- for(int i=1;i<=120;i++)
- {
- for(int j=b[i];j<p;j++)
- {
- c[i][j]=fac[j]*inv[j-b[i]]%p*inv[b[i]]%p;
- }
- }
- solve(r);
- for(int i=1;i<p;i++)
- {
- ans[i]=f[i];
- }
- solve(l-1);
- for(int i=1;i<p;i++)
- {
- ans[i]=((ans[i]-f[i])%MOD+MOD)%MOD;
- }
- ans[0]=(r-l+1)%MOD;
- for(int i=1;i<p;i++)
- {
- ans[0]=((ans[0]-ans[i])%MOD+MOD)%MOD;
- }
- for(int i=0;i<p;i++)
- {
- printf("%lld\n",ans[i]);
- }
- }
[UOJ86]mx 的组合数 --NTT + 数位 DP + 原根与指标 + 卢卡斯定理
来源: http://www.bubuko.com/infodetail-2973102.html