Algorithm: 多项式乘法 Polynomial Multiplication: 快速傅里叶变换 FFT / 快速数论变换 NTT

2019/04/10 10:10

### Intro:

朴素乘法

Prerequisite knowledge:

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int x;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); x starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,x=c&15;else k=x=0;
// Append the remaining digits: (x<<1)+(x<<3) equals x*10, c&15 is the digit value.
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define N (2000010)
int n,m,a[N],b,c[N];
signed main(){
Rd(n),Rd(m);
Frn1(i,0,n)Rd(a[i]);
Frn1(i,0,m){Rd(b);Frn1(j,0,n)c[i+j]+=b*a[j];}
Frn1(i,0,n+m)wr(c[i]),Ps;
exit(0);
}


Time complexity: $O(nm)$，如果$m=O(n)$，则为$O(n^2)$

Memory complexity: $O(n)$

朴素分治乘法

P.s 这一部分讲述了FFT的分治方法，与FFT还是有区别的，如果已经理解的可以跳过

《算法导论》

Prerequisite knowledge:

$A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2),B(x)=B^{[0]}(x^2)+xB^{[1]}(x^2)$

P.s 以下的公式中，用$A$表示$A(x)$，$A^{[0]}$和$A^{[1]}$分别表示$A^{[0]}(x^2)$和$A^{[1]}(x^2)$，$B$同理

$AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}$

P.s 注意合并方式：$A^{[0]}$和$A^{[1]}$分别表示$A^{[0]}(x^2)$和$A^{[1]}(x^2)$，所以是交错的，见代码

（为了省空间用了vector）

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int x;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); x starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,x=c&15;else k=x=0;
// Append the remaining digits: (x<<1)+(x<<3) equals x*10, c&15 is the digit value.
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s;
Vct a,b,c;
void mlt(Vct&a,Vct&b,Vct&c,int n);
// Reads the degrees n and m and both coefficient lists, multiplies the
// polynomials with the divide-and-conquer routine mlt, and prints the n+m+1
// coefficients of the product.
signed main(){
    Rd(n),Rd(m);
    // s = smallest power of two strictly greater than max(n,m). Computed with
    // integer shifts: the floating-point form 1<<int(log2(max(n,m))+1) gives
    // the same value for max(n,m)>=1 but evaluates log2(0) when n==m==0.
    const int top(max(n,m));
    s=1;
    while(s<=top)s<<=1;
    a.resize(s),b.resize(s),c.resize(s<<1);
    Frn1(i,0,n)Rd(a[i]);
    Frn1(i,0,m)Rd(b[i]);
    mlt(a,b,c,s);
    Frn1(i,0,n+m)wr(c[i]),Ps;
    exit(0);
}
// Divide-and-conquer polynomial product using four half-size products.
//   a, b : coefficient vectors; the first n entries of each are used
//   c    : output, must hold at least 2n-1 entries; c[0..2n-2] is overwritten
//   n    : number of coefficients, must be a power of two
// With A = A0(x^2) + x*A1(x^2) and B likewise:
//   A*B = A0*B0 + x*(A0*B1 + A1*B0) + x^2 * A1*B1   (all evaluated at x^2)
// The parameter type is spelled out; it is the same type as the file's Vct.
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c,int n){
    if(n==1){c[0]=a[0]*b[0];return;}
    const int n2(n>>1);
    // Split both inputs into even-indexed and odd-indexed coefficients.
    std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2);
    for(int i=0;i<n2;++i){
        a0[i]=a[i<<1],a1[i]=a[i<<1|1];
        b0[i]=b[i<<1],b1[i]=b[i<<1|1];
    }
    // Each half-size product has n-1 coefficients; slot n-1 stays 0.
    std::vector<int> ab0(n),ab1(n),abm(n),cross2(n);
    mlt(a0,b0,ab0,n2);
    mlt(a1,b1,ab1,n2);
    // Cross term A0*B1 + A1*B0: it supplies every odd-index coefficient.
    mlt(a0,b1,abm,n2);
    mlt(a1,b0,cross2,n2);
    for(int i=0;i<n;++i)abm[i]+=cross2[i];
    // Interleave: even slots take A0*B0 plus A1*B1 shifted by x^2, odd slots the cross term.
    for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
    for(int i=0;i<n-1;++i)c[i<<1|1]=abm[i];
}


$T(n)=4T(n/2)+f(n)$，其中$f(n)=O(n)$（就是$n$位加法的时间）

分治乘法

**先来一个小插曲：**如何只做$3$次乘法，求出线性多项式$ax+b$与$cx+d$的乘积

$AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}$

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int x;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); x starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,x=c&15;else k=x=0;
// Append the remaining digits: (x<<1)+(x<<3) equals x*10, c&15 is the digit value.
while(isdigit(Gc(c)))x=(x<<1)+(x<<3)+(c&15);
return k?x:-x;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s;
Vct a,b,c;
// Element-wise difference: stores a[i]-b[i] into c[i] for every index of c.
// a and b must be at least as long as c.
void mns(Vct&a,Vct&b,Vct&c){
    for(int pos=0;pos<c.size();++pos){
        c[pos]=a[pos]-b[pos];
    }
}
void mlt(Vct&a,Vct&b,Vct&c);
// Reads the degrees n and m and both coefficient lists, multiplies the
// polynomials with the three-product divide-and-conquer routine mlt, and
// prints the n+m+1 coefficients of the product.
signed main(){
    Rd(n),Rd(m);
    // s = smallest power of two strictly greater than max(n,m). Computed with
    // integer shifts: the floating-point form 1<<int(log2(max(n,m))+1) gives
    // the same value for max(n,m)>=1 but evaluates log2(0) when n==m==0.
    const int top(max(n,m));
    s=1;
    while(s<=top)s<<=1;
    a.resize(s),b.resize(s),c.resize(s<<1);
    Frn1(i,0,n)Rd(a[i]);
    Frn1(i,0,m)Rd(b[i]);
    mlt(a,b,c);
    Frn1(i,0,n+m)wr(c[i]),Ps;
    exit(0);
}
// Divide-and-conquer polynomial product using three half-size products.
//   a, b : coefficient vectors; a.size() must be a power of two and b must be
//          at least as long as a
//   c    : output, must hold at least 2*a.size()-1 entries
// With A = A0(x^2) + x*A1(x^2) and B likewise:
//   A*B = A0*B0 + x*M + x^2 * A1*B1,
//   M   = (A0+A1)*(B0+B1) - A0*B0 - A1*B1   (equals A0*B1 + A1*B0)
// The parameter type is spelled out; it is the same type as the file's Vct.
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c){
    const int n(a.size());
    if(n==1){c[0]=a[0]*b[0];return;}
    const int n2(n>>1);
    // Split into even/odd coefficients and build the sums needed for M.
    std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2),asum(n2),bsum(n2);
    for(int i=0;i<n2;++i){
        a0[i]=a[i<<1],a1[i]=a[i<<1|1];
        b0[i]=b[i<<1],b1[i]=b[i<<1|1];
        asum[i]=a0[i]+a1[i],bsum[i]=b0[i]+b1[i];
    }
    // Each half-size product has n-1 coefficients; slot n-1 stays 0.
    std::vector<int> ab0(n),ab1(n),abm(n);
    mlt(a0,b0,ab0);
    mlt(a1,b1,ab1);
    mlt(asum,bsum,abm);
    // Remove the two known products to leave the cross term M.
    for(int i=0;i<n;++i)abm[i]-=ab0[i]+ab1[i];
    // Interleave: even slots take A0*B0 plus A1*B1 shifted by x^2, odd slots take M.
    for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
    for(int i=0;i<n-1;++i)c[i<<1|1]=abm[i];
}


$T(n)=3T(n/2)+f(n)$，其中$f(n)=O(n)$

Reason 1. 分治乘法的常数因子太大

Reason 2. 打开$#5$数据一看，$n=1,m=3e6$，那么$O(n^{\log_2 3})$的分治乘法也顶不过$O(nm)$的朴素乘法啊……

快速傅里叶变换 FFT (Fast Fourier Transform)

Fairly Frightening Transform

《算法导论》

Part 1: 多项式的两种表示方式

1. 系数表达

1. 求值$O(n)$

2. 加法$O(n)$

3. 乘法朴素$O(n^2)$，优化$O(n^{\log_2 3})$（即分治乘法）

P.s 当多项式$C(x)=A(x)B(x)$时，$\pmb{c}$被称为$\pmb{a}$与$\pmb{b}$的卷积(convolution)，记为$\pmb{c}=\pmb{a}\bigotimes\pmb{b}$

2. 点值表达

$\left[\begin{matrix}1&x_0&x_0^2&\cdots&x_0^{n-1}\\1&x_1&x_1^2&\cdots&x_1^{n-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&x_{n-1}&x_{n-1}^2&\cdots&x_{n-1}^{n-1}\end{matrix}\right]\left[\begin{matrix}a_0\\a_1\\\vdots\\a_{n-1}\end{matrix}\right]=\left[\begin{matrix}y_0\\y_1\\\vdots\\y_{n-1}\end{matrix}\right]$

1. 加法$O(n)$（只要将各个位置的$y$值相加即可）

2. 乘法$O(n)$（同理）

### Part 2: 单位复数根及其性质

$n$次单位复数根是满足$\omega^n=1$的复数$\omega$，正好有$n$个，记为：

$\omega_n^k=e^{2\pi ik/n}=\cos(2\pi k/n)+i\sin(2\pi k/n)$

1. 消去引理：对任何整数$n\geqslant 0,k\geqslant 0,d>0$，有$\omega_{dn}^{dk}=\omega_n^k$

**Proof: **$\omega_{dn}^{dk}=(e^{2\pi i/dn})^{dk}=(e^{2\pi i/n})^k=\omega_n^k$

2. 折半引理：对任何偶数$n$和整数$k$，有$(\omega_n^k)^2=(\omega_n^{k+n/2})^2=\omega_{n/2}^k$

**Proof: **$(\omega_n^k)^2=\omega_n^{2k},(\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}$，最后用消去引理，$\omega_n^{2k}=\omega_{n/2}^k$

3. 求和引理：对任何整数$n\geqslant 1$与不能被$n$整除的非负整数$k$（即$n\nmid k$），有$\sum_{j=0}^{n-1}(\omega_n^k)^j=0$

**Proof: **利用等比数列求和公式，$\sum_{j=0}^{n-1}(\omega_n^k)^j=\frac{1-(\omega_n^k)^n}{1-\omega_n^k}=\frac{1-\omega_n^{nk}}{1-\omega_n^k}=\frac{1-1}{1-\omega_n^k}=0$，为了使分母$1-\omega_n^k\neq 0$，必须满足$\omega_n^k\neq 1\implies n\nmid k$

Part 3: 离散傅里叶变换 DFT (Discrete Fourier Transform)

DFT就是将次数界为$n$的多项式$A(x)$在**$n$次单位复数根求值**的过程

$V_n=V(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1})=\left[\begin{matrix}1&1&1&1&\cdots&1\\1&\omega_n&\omega_n^2&\omega_n^3&\cdots&\omega_n^{n-1}\\1&\omega_n^2&\omega_n^4&\omega_n^6&\cdots&\omega_n^{2(n-1)}\\1&\omega_n^3&\omega_n^6&\omega_n^9&\cdots&\omega_n^{3(n-1)}\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\omega_n^{3(n-1)}&\cdots&\omega_n^{(n-1)(n-1)}\end{matrix}\right]$

Part 4: FFT

FFT利用单位根的特殊性质把DFT优化到了$O(n\log n)$

2. 合并答案

$\omega_n^{n/2}=e^{2\pi i (n/2)/n}=e^{\pi i}=-1$（根据传说中的最美公式$e^{i\pi}+1=0$）

递归边界：$n=1$，那么$\omega_1^0 a_0=a_0$，所以直接返回自身

$T(n)=2T(n/2)+f(n)$，其中$f(n)=O(n)$（合并答案）

Part 5: 离散傅里叶逆变换

定理：对$i,j=0,1,\cdots,n-1$，有$[V_n^{-1}]_{ij}=\omega_n^{-ij}/n$

**Proof: **证明$V_n^{-1}V_n=I_n$即可

$[V_n^{-1}V_n]_{ij}=\sum_{k=0}^{n-1}(\omega_n^{-ik}/n)\omega_n^{kj}=\frac{\sum_{k=0}^{n-1}\omega_n^{-ik}\omega_n^{kj}}{n}=\frac{\sum_{k=0}^{n-1}\omega_n^{(j-i)k}}{n}$

Part 6: 递归实现

STL提供了现成的complex类可供使用

**P.s **最后别忘了$/n$，而且$+0.5$为了四舍五入提高精度

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int u;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); u starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,u=c&15;else k=u=0;
// Append the remaining digits: (u<<1)+(u<<3) equals u*10, c&15 is the digit value.
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx o,w,a[N],b[N],tmp[N],x,y;
int n,m,s;
bool iv;
void fft(Cpx*a,int n);
// Multiplies two polynomials with the recursive FFT: forward-transform both
// inputs, multiply point-wise, inverse-transform, then divide by the
// transform length and round to the nearest integer.
signed main(){
    Rd(n),Rd(m);
    // s = smallest power of two strictly greater than n+m (the product has
    // n+m+1 coefficients). Computed with integer shifts: the floating-point
    // form 1<<int(log2(n+m)+1) gives the same value for n+m>=1 but evaluates
    // log2(0) when n==m==0.
    s=1;
    while(s<=n+m)s<<=1;
    Frn1(i,0,n)Rd(a[i]);
    Frn1(i,0,m)Rd(b[i]);
    fft(a,s),fft(b,s);
    Frn0(i,0,s)a[i]*=b[i];
    iv=1,fft(a,s); // iv selects the inverse transform inside fft
    // floor(v+0.5) rounds to nearest for negative coefficients too; converting
    // v+0.5 straight to int truncates toward zero and is off by one for them.
    Frn1(i,0,n+m)wr(floor(a[i].real()/s+0.5)),Ps;
    exit(0);
}
// Recursive in-place FFT of a[0..n-1]; n must be a power of two.
// The global flag iv picks the sign of the twiddle angle (inverse when set);
// no scaling by 1/n is done here. Uses the global scratch buffer tmp and the
// global temporaries o, w, x, y.
void fft(Cpx*a,int n){
    if(n==1)return;
    const int half(n>>1);
    // Move even-indexed entries to the front half, odd-indexed to the back half.
    for(int k=0;k<half;++k){
        tmp[k]=a[k<<1];
        tmp[k+half]=a[k<<1|1];
    }
    copy(tmp,tmp+n,a);
    fft(a,half);
    fft(a+half,half);
    // o is the principal n-th root of unity (angle 2*pi/n = pi/half), conjugated for the inverse.
    o={cos(Pi/half),(iv?-1:1)*sin(Pi/half)};
    w=1;
    // Butterfly: combine the two half-size transforms.
    for(int k=0;k<half;++k){
        x=a[k];
        y=w*a[k+half];
        a[k]=x+y;
        a[k+half]=x-y;
        w*=o;
    }
}


Time complexity: $O(n\log n)$

Memory complexity: $O(n)$

Part 7: 迭代实现

设$l=\lceil\log_2(n+m+1)\rceil,s=2^l$，那么$A(x),B(x),A(x)B(x)$都是次数界为$s$的多项式

0-> 0 1 2 3 4 5 6 7
1-> 0 2 4 6|1 3 5 7
2-> 0 4|2 6|1 5|3 7
end 0|4|2|6|1|5|3|7


0-> 000 001 010 011 100 101 110 111
1-> 000 010 100 110|001 011 101 111
2-> 000 100|010 110|001 101|011 111
end 000|100|010|110|001|101|011|111


**一个较为感性的Proof: **因为是按照奇偶性分类，也就是说在第$i$层递归时判断的是该编号二进制第$i$位（从零开始），为$0$放左边，$1$放右边，而放右边的结果就是它的位置编号的二进制第$l-i-1$位是$1$

蝴蝶操作 (Butterfly Operation)

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int u;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); u starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,u=c&15;else k=u=0;
// Append the remaining digits: (u<<1)+(u<<3) equals u*10, c&15 is the digit value.
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx a[N],b[N],o,w,x,y;
int n,m,l,s,r[N];
void fft(Cpx*a,bool iv);
// Multiplies two polynomials with the iterative FFT: build the bit-reversal
// table, forward-transform both inputs, multiply point-wise, inverse-transform
// (fft scales by 1/s itself), then round to the nearest integer.
signed main(){
    Rd(n),Rd(m);
    // s = 2^l = smallest power of two strictly greater than n+m. Computed with
    // integer shifts: the floating-point form log2(n+m)+1 gives the same l for
    // n+m>=1 but evaluates log2(0) when n==m==0.
    l=0,s=1;
    while(s<=n+m)s<<=1,++l;
    Frn1(i,0,n)Rd(a[i]);
    Frn1(i,0,m)Rd(b[i]);
    // r[i] = i with its l bits reversed. r[0] is already 0 (global array), and
    // starting at 1 means the loop only runs when l>=1, so the shift count
    // l-1 is never negative.
    Frn0(i,1,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    fft(a,0),fft(b,0);
    Frn0(i,0,s)a[i]*=b[i];
    fft(a,1);
    // floor(v+0.5) rounds to nearest for negative coefficients too; converting
    // v+0.5 straight to int truncates toward zero and is off by one for them.
    Frn1(i,0,n+m)wr(floor(a[i].real()+0.5)),Ps;
    exit(0);
}
// Iterative in-place FFT of a[0..s-1], where s is the global transform length
// and r the global bit-reversal table. With iv set the twiddle angle is
// negated and the result is divided by s (inverse transform).
// Uses the global temporaries o, w, x, y.
void fft(Cpx*a,bool iv){
    // Permute into bit-reversed order; the i<target test swaps each pair once.
    for(int i=0;i<s;++i){
        const int target(r[i]);
        if(i<target)swap(a[i],a[target]);
    }
    // Merge blocks of length `half` into blocks of length `len`, bottom-up.
    for(int len(2),half(1);len<=s;half=len,len<<=1){
        o={cos(Pi/half),(iv?-1:1)*sin(Pi/half)};
        for(int base(0);base<s;base+=len){
            w=1;
            for(int k=0;k<half;++k){
                // Butterfly on the pair (base+k, base+k+half).
                x=a[base+k];
                y=w*a[base+k+half];
                a[base+k]=x+y;
                a[base+k+half]=x-y;
                w*=o;
            }
        }
    }
    if(iv){
        for(int i=0;i<s;++i)a[i]/=s;
    }
}


Time complexity: $O(n\log n)$

Memory complexity: $O(n)$

Extension: 快速数论变换 NTT (Number Theoretic Transform)

《算法导论》

Prerequisite knowledge:

FFT（必须知道的）

原根的性质

**E.g **对于$P=7$，计算所有$<P$的正整数的次幂构成的集合

1-> {1}
2-> {1,2,4}
3-> {1,2,3,4,5,6}
4-> {1,2,4}
5-> {1,2,3,4,5,6}
6-> {1,6}


**E.g **$P=7,g=3$

#### 单位根的代替品

离散对数定理：如果$g$是$Z_P^*$的一个原根，则$x\equiv y\pmod{\phi(P)}\iff g^x\equiv g^y\pmod P$

**Proof: **设$x\equiv y\pmod{\phi(P)}$，则对某个整数$k$有$x=y+k\phi(P)$

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// NOTE(review): the signature line of this input helper is missing from the
// listing (the body below ends with `return` and `}` but has no opening line,
// and main() calls `Rd(...)`, which is not defined in the visible text).
// Restore it from the original program before compiling.
// Reads one optionally negative decimal integer from stdin via getchar().
register int u;register char c(getchar());register bool k;
// Skip characters until a digit or '-' appears; terminate the program on EOF.
while(!isdigit(c)&&c^'-')if(Gc(c)==EOF)exit(0);
// k is the sign flag (1 = non-negative); u starts at the first digit, or at 0 after a '-'.
if(c^'-')k=1,u=c&15;else k=u=0;
// Append the remaining digits: (u<<1)+(u<<3) equals u*10, c&15 is the digit value.
while(isdigit(Gc(c)))u=(u<<1)+(u<<3)+(c&15);
return k?u:-u;
}
// Writes the decimal representation of `a` to stdout (no separator is added).
// The magnitude is computed in unsigned arithmetic so that the most negative
// value prints correctly instead of overflowing on `a=-a`.
void wr(int a){
    unsigned int mag(a);
    if(a<0){
        Pc('-');
        mag=0u-mag; // unsigned negation is well defined, even for the minimum value
    }
    char digits[24]; // large enough for every 64-bit value
    int len(0);
    // Collect digits least-significant first, then emit them in reverse.
    do digits[len++]=char('0'+mag%10),mag/=10;while(mag);
    while(len)Pc(digits[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
#define Frn0(i,a,b) for(register int i(a);i<(b);++i)
#define Frn1(i,a,b) for(register int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(register int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define P (998244353)
#define G (3)
#define Gi (332748118)
#define N (2100000)
int n,m,l,s,r[N],a[N],b[N],o,w,x,y,siv;
// Modular exponentiation by repeated squaring: a raised to the power p, modulo P.
// Returns 1 when p is 0, and returns a unchanged when a is 0 or 1.
int fpw(int a,int p){
    if(!p)return 1;
    if(!(a>>1))return a; // a is 0 or 1: every positive power equals a
    const int squared(fpw(a*a%P,p>>1)); // (a^2)^(p/2)
    return (p&1?a:1)*squared%P; // odd exponent needs one extra factor of a
}
void ntt(int*a,bool iv);
// Multiplies two polynomials modulo P with the NTT: build the bit-reversal
// table, forward-transform both inputs, multiply point-wise modulo P, then
// inverse-transform (ntt scales by siv, the modular inverse of s).
signed main(){
    Rd(n),Rd(m);
    // s = 2^l = smallest power of two strictly greater than n+m. Computed with
    // integer shifts: the floating-point form log2(n+m)+1 gives the same l for
    // n+m>=1 but evaluates log2(0) when n==m==0.
    l=0,s=1;
    while(s<=n+m)s<<=1,++l;
    siv=fpw(s,P-2); // s^(P-2) mod P, the inverse of s by Fermat's little theorem
    Frn1(i,0,n)Rd(a[i]);
    Frn1(i,0,m)Rd(b[i]);
    // r[i] = i with its l bits reversed. r[0] is already 0 (global array), and
    // starting at 1 means the loop only runs when l>=1, so the shift count
    // l-1 is never negative.
    Frn0(i,1,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    ntt(a,0),ntt(b,0);
    Frn0(i,0,s)a[i]=a[i]*b[i]%P;
    ntt(a,1);
    Frn1(i,0,n+m)wr(a[i]),Ps;
    exit(0);
}
void ntt(int*a,bool iv){
Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
for(int i(2),i2(1);i<=s;i2=i,i<<=1){
o=fpw(iv?Gi:G,(P-1)/i);
for(int j(0);j<s;j+=i){
w=1;
Frn0(k,0,i2){
x=a[j+k],y=w*a[j+k+i2]%P;
a[j+k]=(x+y)%P,a[j+k+i2]=(x-y+P)%P,w=w*o%P;
}
}
}
if(iv)Frn0(i,0,s)a[i]=a[i]*siv%P;
}


Time complexity: $O(n\log n)$

Memory complexity: $O(n)$

0
0 收藏

0 评论
0 收藏
0