已知多项式 $A(x)$, 若存在 $A(x)B(x)\equiv 1\pmod{x^n}$
则称 $B(x)$ 为 $A(x)$ 在模 $x^n$ 下的逆元, 记做 $A^{-1}(x)$
具体的来说的话, 就是两个多项式 $A,B$ 相乘模 $x^n$ 之后, 所有次数大于等于 $n$ 的项都没了, 那么只有在剩下的项相乘之后未知数项全被消掉只留下一个常数项 $1$ 时,$B$ 才是 $A$ 的逆元
然后为什么要有模 $x^n$ 的限制呢? 因为没有这个限制的话,$B$ 可能有无穷多项
然后我们考虑如何计算 $B(x)$
当 $n=1$ 的时候,$A(x)\equiv c\pmod{x}$, 其中 $c$ 为常数项, 那么 $A^{-1}(x)$ 就是 $c^{-1}$
当 $n>1$ 时 $$B(x)A(x)\equiv 1\pmod{x^n}$$
设 $B'(x)$ 是模 $x^{\left\lceil\frac{n}{2}\right\rceil}$ 时的逆元, 即 $$B'(x)A(x)\equiv 1\pmod{x^{\left\lceil\frac{n}{2}\right\rceil}}$$
首先, 可以肯定 $$B(x)A(x)\equiv 1\pmod{x^{\left\lceil\frac{n}{2}\right\rceil}}$$
那么上下两个式子相减可得 $$B(x)-B'(x)\equiv 0\pmod{{x^{\left\lceil\frac{n}{2}\right\rceil}}}$$
然后两边平方 $$B^2(x)+2B'(x)B(x)+B'^2(x)\equiv 0\pmod{{x^n}}$$
为什么上面模数变成 $x^n$ 呢? 我们考虑如果一个多项式在 $\pmod{x^n}$ 的情况下为 $0$, 那么说明 $0$ 到 $n-1$ 项的系数也为 $0$, 它平方之后 $0$ 到 $2n-1$ 项系数 $a_i$ 为 $\sum_{j=0}^ia_ja_{i-j}$, 那么 $j$ 和 $i-j$ 中必有一个小于 $n$, 也就是说 $a_j$ 和 $a_{i-j}$ 里必有一个为 $0$, 那么 $a_i$ 也是 $0$, 所以平方之后在 $\mod{2n}$ 也为 $0$
然后在上式两边同乘 $A(x)$ 并移项可得 $$B(x)\equiv2B'(x)-A(x)B'^2(x)\pmod{x^n}$$
那么发现这个东西可以递归计算, 时间复杂度为 $O(nlogn)$
- //minamoto
- #include<iostream>
- #include<cstdio>
- #include<algorithm>
- #define swap(x,y) (x^=y,y^=x,x^=y)
- #define mul(x,y) (1ll*x*y%P)
- #define add(x,y) (x+y>=P?x+y-P:x+y)
- #define dec(x,y) (x-y<0?x-y+P:x-y)
- using namespace std;
- #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
- char buf[1<<21],*p1=buf,*p2=buf;
- inline int read(){
- #define num ch-'0'
- char ch;bool flag=0;int res;
- while(!isdigit(ch=getc()))
- (ch=='-')&&(flag=true);
- for(res=num;isdigit(ch=getc());res=res*10+num);
- (flag)&&(res=-res);
- #undef num
- return res;
- }
- char sr[1<<21],z[20];int C=-1,Z;
- inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
- inline void print(int x){
- if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
- while(z[++Z]=x%10+48,x/=10);
- while(sr[++C]=z[Z],--Z);sr[++C]=' ';
- }
- const int N=(1<<21)+5,P=998244353,G=3,Gi=332748118;
- inline int ksm(int a,int b){
- int res=1;
- while(b){
- if(b&1) res=mul(res,a);
- a=mul(a,a),b>>=1;
- }
- return res;
- }
- int n,r[N],X[N],Y[N],A[N],B[N],O[N];
- void NTT(int *A,int type,int len){
- int limit=1,l=0;
- while(limit<len) limit<<=1,++l;
- for(int i=0;i<limit;++i)
- r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
- for(int i=0;i<limit;++i)
- if(i<r[i]) swap(A[i],A[r[i]]);
- for(int mid=1;mid<limit;mid<<=1){
- int R=mid<<1,Wn=ksm(G,(P-1)/R);O[0]=1;
- for(int j=1;j<mid;++j) O[j]=mul(O[j-1],Wn);
- for(int j=0;j<limit;j+=R){
- for(int k=0;k<mid;++k){
- int x=A[j+k],y=mul(O[k],A[j+k+mid]);
- A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
- }
- }
- }
- if(type==-1){
- // 这里这么写是因为如果要点值转系数直接 reverse 再除以 n(也就是乘个逆元) 就好了
- reverse(A+1,A+limit);
- for(int i=0,inv=ksm(len,P-2);i<limit;++i)
- A[i]=mul(A[i],inv);
- }
- }
- void work(int *a,int *b,int len){
- if(len==1) return (void)(b[0]=ksm(a[0],P-2));
- work(a,b,len>>1);
- for(int i=0;i<len;++i) A[i]=a[i],B[i]=b[i];
- NTT(A,1,len<<1),NTT(B,1,len<<1);
- for(int i=0;i<(len<<1);++i)
- A[i]=mul(mul(A[i],B[i]),B[i]);
- NTT(A,-1,len<<1);
- for(int i=0;i<len;++i) b[i]=(1ll*(b[i]<<1)%P+P-A[i])%P;
- }
- int main(){
- // freopen("testdata.in","r",stdin);
- n=read();
- for(int i=0;i<n;++i) X[i]=(read()+P)%P;
- int len;for(len=1;len<n;len<<=1);
- work(X,Y,len);
- for(int i=0;i<n;++i) print(Y[i]);
- Ot();
- return 0;
- }
来源: http://www.bubuko.com/infodetail-2794898.html