引入:
我们先来看一道例题:
给定序列 \(\{g_1,\dots g_{n-1}\}\), 已知 \(f_0=1\) ,\(f_n=\sum_{i=1}^{n} f_{n-i}\times g_i\), 求序列 \(\{f_i\}\), 对 \(998244353\) 取模
考虑最朴素的做法,\(O(n^2)\), 在 \(n\) 比较小的情况下是可以的
考虑 \(f_i\) 的计算显然可以化成卷积的形式, 但 \(f_i\) 的计算依赖于之前的 \(f\), 直接 \(NTT\), 退化成 \(O(n^2 \log n)\)
那么在 \(n\) 比较大的情况下, 我们要怎么来做这个东西呢?
正题:
在引入, 我们得到一个 \(O(n^2 \log n)\) 的做法, 接下来, 考虑如何优化这个做法
我们考虑分治, 现在我们要计算 \([l,r]\) 的 \(f\), 把它分成 \([l,mid]\) 和 \([mid+1,r]\), 现在假设我们已经算出了 \([l,mid]\)
考虑计算 \([l,mid]\) 的 \(f\) 对 \([mid+1,r]\) 的 \(f\) 的贡献, 对于一个 \(mid<x \le r\), 设 \(w[x]\) 为 \([l,mid]\) 对 \(x\) 的贡献
\[ w[x]=\sum_{i=l}^{mid} f_i \times g_{mid-i}\w[x]=\sum_{i=l}^{x} f_i \times g_{x-i}\\]
我们可以直接补到 \(x\), 因为大于 \(mid\) 的部分 \(f\) 为 \(0\), 可以发现,\(w[x]\) 的计算显然可以写成卷积的形式
在这里, 我们令 \(a[i]=f_{i+l}\), 令 \(b[i]=g_{i+1}\), 那么,\(w[x]\) 可以写成这样
\[ w[x]=\sum_{i=0}^{x-l-1} a[i]\times b[x-l-1-i] \]
则我们可以一次 \(NTT\) 直接算出这一部分的贡献, 然后继续分治即可, 时间复杂度 \(O(n \log ^2 n)\)
- Code:
- #include<bits/stdc++.h>
- #define int long long
- using namespace std;
- const int mod=998244353;
- const int N=2e5+11;
- int n,g[N],f[N],A[N],B[N],p[N];
- int read(){
- int x=0,f=1;char ch=getchar();
- while(!isdigit(ch)){if(ch=='-')f=-f;ch=getchar();}
- while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
- return x*f;
- }
- int qpow(int x,int y){
- int re=1;
- while(y>0){
- if(y&1) re=re*x%mod;
- y>>=1;x=x*x%mod;
- }return re;
- }
- void NTT(int *a,int flag,int len){
- for(int i=0;i<len;i++)
- if(i<p[i]) swap(a[i],a[p[i]]);
- for(int l=2;l<=len;l<<=1){
- int wn=qpow(3,(mod-1)/l);
- if(flag==-1) wn=qpow(wn,mod-2);
- for(int st=0;st<len;st+=l){
- int w=1;
- for(int u=st;u<st+(l>>1);u++,w=w*wn%mod){
- int x=a[u],y=w*a[u+(l>>1)]%mod;
- a[u]=(x+y)%mod;a[u+(l>>1)]=(x+mod-y)%mod;
- }
- }
- }
- }
- void Transform(int len){
- NTT(A,1,len);NTT(B,1,len);
- for(int i=0;i<=len;i++) A[i]=A[i]*B[i]%mod;
- NTT(A,-1,len);int inv=qpow(len,mod-2);
- for(int i=0;i<len;i++) A[i]=A[i]*inv%mod;
- }
- void DivideT(int l,int r){
- if(l==r) return ;
- int mid=l+r>>1;
- DivideT(l,mid);
- int len=1,tim=0,sz=r-l-1;
- while(len<=sz) len<<=1,++tim;
- for(int i=0;i<len;i++)
- p[i]=(p[i>>1]>>1)|((i&1)<<(tim-1));
- for(int i=0;i<len;i++) A[i]=B[i]=0;
- for(int i=0;i<=mid-l;i++) A[i]=f[i+l];
- for(int i=0;i<=r-l-1;i++) B[i]=g[i+1];
- Transform(len);
- for(int i=mid+1;i<=r;i++) f[i]=(f[i]+A[i-l-1])%mod;
- DivideT(mid+1,r);
- }
- signed main(){
- n=read();f[0]=1;
- for(int i=1;i<n;i++) g[i]=read();
- DivideT(0,n-1);
- for(int i=0;i<n;i++) printf("%lld",f[i]);
- return 0;
- }
来源: http://www.bubuko.com/infodetail-3416870.html