FFT讲解

引入

傅里叶变换是把函数从时间域转化到频域的数学变换,FFT是其快速实现方式。本文主要讨论的是FFT在ACM竞赛中的应用。

FFT用来解决多项式相乘问题,可以把 $O(n^2)$ 的时间复杂度优化到 $O(nlg(n))$ 。

一个 $n$ 次多项式可以表示为 $F(x) = \sum_{i=0}^{n}a_{i}x^i$ ,现在再有一个多项式 $G(x)$ ,让求两个多项式相乘 $C(x)$ 的系数,有 \(c_{k} = \sum_{i+j=k}a_ib_j\) 其本质是加法卷积。两个多项式的卷积写作 \(c_k = \sum_{i \bigoplus j=k}a_ib_j\) 其中 $\bigoplus$ 代表某种运算,在多项式乘法中就代表了加法。还有位运算等形式,会在之后介绍。

根据上述理论,有多项式乘法的暴力实现代码, $O(n^2)$ :

1
2
3
4
5
void mul(int *c, int *a, int *b){
	for(int i=0; i<=n; i++) //最高次为n和m
        for(int j=0; j<=m; j++)
            c[i+j] += a[i]*b[j];
}

点值表示法

先把多项式化为函数 \(F(x) = \sum_{i=0}^{n}a_ix^i,G(x) = \sum_{i=0}^{n}b_ix^i\) 现在要求 $C(x) = \sum_{i=0}^{2n}c_ix^i$ 的系数 $c_i$ 。

我们知道用 $n+1$ 个点可以确定一个 $n$ 次函数,取 $n+1$ 点 $x_i$ 带入 $F(x)$ 和 $G(x)$ 可以得到对应的 $y$ 值,用这 $n+1$ 个点也可以唯一的表示 $F(x)$ 和 $G(x)$ 。因为 $C(x_i) = F(x_i)G(x_i)$ ,所以可以$O(n)$ 的得到 $C(x)$ 的 $n+1$ 个点,但是因为 $C(x)$ 的最高幂为 $n+m$,若想唯一的表示 $C(x)$ 需要 $2n+1$ 个点,所以可以对 $F(x)$ 和 $G(x)$ 取 $2n+1$ 个点,时间复杂度依然不变,所以接下来就要考虑如何把系数表示法转化为点值表示法。

其中把系数表示法转化为点值表示法称作 $DFT$ ,点值表示转化成系数表示称作 $IDFT$ 。而 $FFT$ 就是 $DFT$ 的加速版本。

单位根的性质

以下性质在推到中会用到,在这不做具体证明。

  • $\omega_{n}^{1}$ 的在复数系的坐标为 $(cos\dfrac{2π}{n},sin\dfrac{2π}{n})$ 即 $\omega_{n}^{1} = cos\frac{2\pi}{n} + isin\frac{\pi}{n}$。

  • $\omega_{n}^{k} = \omega_{n}^{k\%n}$
  • $\omega_{n}^{k1}\omega_{n}^{k2} = \omega_{n}^{k1+k2}$
  • $\omega_{2n}^{2k} = \omega_{n}^{k}$
  • $\omega_{n}^{k+n/2} = -\omega_{n}^{k}$ $(n是偶数)$

DFT

把 $n-1$ 次(也叫 $n$ 项)多项式拆分成奇偶两部分 \(F(x) = (a_0+a_2x^2+...+a_{n-2}x^{n-2})+(a_1x+a_3x^3+...+a_{n-1}x^{n-1})\) 这里为了保证可以让两者平均分,必须保证 $n$ 是 $2$ 的倍数。 不足的高次幂可以补 $0$ 。

设下面两个 $\frac{n}{2}$ 项函数 \(L(x) = a_0 + a_2x + ... + a_{n-2}x^{\frac{n}{2}-1} \\ R(x) = a_1 + a_3x + ... + a_{n-1}x^{\frac{n}{2}-1}\) 则 \(F(x) = L(x^2) + xR(x^2)\) 现在要代入 $n$ 次单位根 $\omega_{n}^{k}$ $(k<\frac{n}{2})$ ,即考虑 $(\omega_{n}^{k},y_k)$ $( k<\frac{n}{2} )$ 这些点: \(F(\omega_{n}^{k}) = L((\omega_{n}^{k})^2) + \omega_{n}^{k}R((\omega_{n}^{k})^2)\) 因为 $(\omega_{n}^{k})^2 = \omega_{\frac{n}{2}}^{k}$ ,所以 \(F(\omega_{n}^{k}) = L(\omega_{n/2}^{k}) + \omega_{n}^{k}R(\omega_{n/2}^{k})\) 目前只考虑了$(\omega_{n}^{k},y_k)$ $( k<\frac{n}{2} )$ 这些点,下面考虑 $(\omega_{n}^{k},y_k)$ $( k\geq \frac{n}{2} )$ 的点

把 $\omega_{n}^{k+n/2}$ $(k<\frac{n}{2})$代入: \(F(\omega_{n}^{k+n/2}) = L((\omega_{n}^{k+n/2})^2) + \omega_{n}^{k+n/2}R((\omega_{n}^{k+n/2})^2) \\ = L(\omega_{n}^{2k+n}) + \omega_{n}^{k+n/2}R(\omega_{n}^{2k+n})\\ = L(\omega_{n}^{2k}) + \omega_{n}^{k+n/2}R(\omega_{n}^{2k}) \\ = L(\omega_{n/2}^{k}) + \omega_{n}^{k+n/2}R(\omega_{n/2}^{k}) \\ = L(\omega_{n/2}^{k}) - \omega_{n}^{k}R(\omega_{n/2}^{k}) \\\) 可以观察到两次代入得出来的结果只差了一个正负号,也就是可用相似的代码解决.

此时若想求出 $(\omega_{n}^{k},y_k)$ $(0\leq k \leq n-1)$ ,可以减少一半,算出 \((\omega_{n/2}^{k},y_k)\) 即可通过上式得出 \((\omega_{n}^{k},y_k)\) 的值,考虑递归求解,这时就要保证每次递归的时候都能把新的 $ n$ 平均分成两半,即原来的 $n$ 必须是 $2$ 的正整数次幂 ,不足的用 $0$ 补全高次。

根据上述推导可以写出 $FFT$ 的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <cstdio>
#include <cmath>
#define Maxn 100
#define Pi acos(-1.0)
using namespace std;

struct complex{
    double x,y;
    complex(double xx=0,double yy=0){x=xx,y=yy;}
};
complex operator + (complex a,complex b){return complex(a.x+b.x,a.y+b.y);}
complex operator - (complex a,complex b){return complex(a.x-b.x,a.y-b.y);}
complex operator * (complex a,complex b){return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

void DFT(complex *f, int len){
	if (!len) return ;

	complex L[len+1], R[len+1];
	for (int k=0;k<len;k++){L[k]=f[k<<1]; R[k]=f[k<<1|1];} // 分奇偶

	DFT(L,len>>1); DFT(R,len>>1); //递归求解

	complex tmp(cos(Pi/len), sin(Pi/len)), buf(1, 0);
	for (int k=0; k<len; k++){
		f[k] = L[k] + buf*R[k];
		f[k+len] = L[k] - buf*R[k];
	 	buf = buf*tmp;
	}
}

int n, m;
complex b[Maxn];
int main(){
	scanf("%d",&n);
	for(int i=0; i<n; i++) 
		scanf("%lf", &b[i].x);

	for(m=1; m<n; m<<=1); //长度补到2的幂数

	DFT(b, m>>1);

	for(int i=0; i<m; ++i)
		printf("%.4f ",b[i].x);
}

IDFT

上述操作是把多项式化为了点值表示法,若想知道多项式的各系数还要把点值表示法转化为系数表示法。

系数表示法到点值表示法(DFT)有 \(Y(k) = \sum_{i=0}^{n-1}(\omega_n^k)^iF(i)\) IDFT 有 \(nF(k) = \sum_{i=0}^{n-1}(\omega_n^{-k})^iF(i)\) 仅有带入的单位根不同并多了一个 $n$ ,在此不做证明了(懒得打了),结合上一份代码有:

1
2
3
4
5
6
7
8
9
10
11
12
void FFT(complex *f, int len, int op){
	if (!len) return ;
	complex L[len+1], R[len+1];
	for (int k=0;k<len;k++){L[k]=f[k<<1]; R[k]=f[k<<1|1];} // 分奇偶
	FFT(L, len>>1, op); FFT(R, len>>1, op); //递归求解
	complex tmp(cos(Pi/len), op*sin(Pi/len)), buf(1, 0);
	for (int k=0; k<len; k++){
		f[k] = L[k] + buf*R[k];
		f[k+len] = L[k] - buf*R[k];
	 	buf = buf*tmp;
	}
}

$op$ 取 $ 1$是 $DFT$,取0是 $IDFT$,要记得 $IDFT$ 后要除以 $n$ 。

迭代优化写法

由于递归的常数太大,又有人发明了递归的写法,具体证明不做介绍,直接贴代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int n, m, r[maxn<<2];
complex a[maxn<<2], b[maxn<<2];
void FFT(complex *f, short op){
	for(int i=0; i<n; i++) if(i < r[i]) 
		swap(f[i], f[r[i]]);
	for(int p=2; p<=n; p<<=1){
		int len = p>>1;
		complex tmp(cos(Pi/len), op*sin(Pi/len));
		for(register int k=0; k<n; k+=p){
			register complex buf(1,0);
			for(register int l=k; l<k+len; l++){
				register complex tt = buf*f[len+l] ;
				f[len+l] = f[l]-tt, f[l] = f[l]+tt;
				buf = buf*tmp;
			}
		}
	}
}
int main(){
	scanf("%d%d",&n, &m);
	for(int i=0; i<=n; i++) scanf("%lf", &a[i].x);
	for(int i=0; i<=m; i++) scanf("%lf", &b[i].x);
	for(m+=n, n=1; n<=m; n<<=1);
	for(int i=0; i<n; i++) // 预处理二进制反转
		r[i] = (r[i>>1]>>1)|((i&1)?n>>1:0);

	FFT(a, 1), FFT(b, 1);
	for(int i=0; i<n; ++i) a[i] = a[i]*b[i];
	FFT(a, -1);
	for(int i=0; i<=m; ++i)
		printf("%.0f ", fabs(a[i].y)/n);
}

假设我们需要求 F(x)G(x)F(x)∗G(x*)

复多项式 $P(x)=F(x)+iG(x)$ 也就是实部为 $F(x)$ , 虚部为 $G(x) $ .

则 $P(x)^2=(F(x)+iG(x))^2=F(x)^2-G(x)^2+2iF(x)G(x)$

发现 $P(x)^2$ 的虚部为 $2F(x)G(x)$

也就是说求出 $P(x)^2 $ 之后, 把它的虚部除以 $2n$ 即可.

减少依次调用写法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int main(){
	scanf("%d%d",&n, &m);
	for(int i=0; i<=n; i++) scanf("%lf", &a[i].x);
	for(int i=0; i<=m; i++) scanf("%lf", &a[i].y);
	for(m+=n, n=1; n<=m; n<<=1);
	for(int i=0; i<n; i++) // 预处理二进制反转
		r[i] = (r[i>>1]>>1)|((i&1)?n>>1:0);

	FFT(a, 1);
	for(int i=0; i<n; ++i) a[i] = a[i]*a[i];
	FFT(a, -1);
	for(int i=0; i<=m; ++i) // 注意这里多除了2
		printf("%.0f ", fabs(a[i].y)/n/2);
}

NTT

由于计算过程中会用到三角函数库,而且还要使用复数,计算比较慢,而且浮点数会丢失精度。若遇到一些让输出答案 $\mod p$ 的题目可以使用 $NTT$ 。

只是用原根做了替换,这里有 原根表 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int maxn = 1e6+5;
#define G 3
#define mod 998244353

LL quick_pow(LL a, LL b=mod-2, LL t=mod){
	LL sum = 1;
	while(b){
		if(b&1) sum = sum*a % t;
		a = a*a % t; b >>= 1;
	}return sum;
}

int n, m, r[maxn<<2];
LL a[maxn<<2], b[maxn<<2], invn, invG;
void NTT(LL *f, short op){
	for(int i=0; i<n; i++) if(i > r[i]) 
		swap(f[i], f[r[i]]);
	for(int p=2; p<=n; p<<=1){
		int len = p>>1, w = quick_pow(op==1?G:invG, (mod-1)/p);
		for(int k=0; k<n; k+=p){
			LL buf = 1;
			for(int l=k; l<k+len; l++){
				int tt = buf*f[len+l] % mod;
				f[len+l] = (f[l]-tt+mod)%mod, f[l] = (f[l]+tt)%mod;
				buf = buf*w % mod;
			}
		}
	}
}
int main(){
	scanf("%d%d",&n, &m);
	for(int i=0; i<=n; i++) scanf("%lld", &a[i]);
	for(int i=0; i<=m; i++) scanf("%lld", &b[i]);
	for(m+=n, n=1; n<=m; n<<=1);
	for(int i=0; i<n; i++) 
		r[i] = (r[i>>1]>>1)|((i&1)?n>>1:0);
	invn = quick_pow(n); invG = quick_pow(G);

	NTT(a, 1), NTT(b, 1);
	for(int i=0; i<n; ++i) a[i] = a[i]*b[i] % mod;
	NTT(a, -1);
	for(int i=0; i<=m; ++i)
		printf("%lld ", a[i]*invn % mod);
}

多项式的各种算法

多项式乘法

上方的NTT即可完成,封装代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<algorithm>
#include<cstdio>
#define LL long long
#define G 3
#define Maxn 1000500
#define mod 998244353
using namespace std;
inline int read();
LL powM(LL a, LL b=mod-2){
	LL ans=1, buf=a;
	while(b){
		if(b & 1) ans = (ans*buf) % mod;
		buf = (buf*buf) % mod; b>>=1;
	}return ans;
}
inline void print(LL *f,int len){
	for(int i=0; i<len; i++) 
		printf("%lld%c", f[i], " \n"[i==len-1]);
}

int n, m, r[Maxn<<2];
LL f[Maxn<<2], g[Maxn<<2], invn, invG;
void NTT(LL *f, int n, short op){
	for(int i=0; i<n; i++) if (r[i]<i) swap(f[r[i]], f[i]);
	for (int p=2; p<=n; p<<=1){
		int len=p>>1,
		w = powM(op==1?G:invG, (mod-1)/p);
		for(int k=0; k<n; k+=p){
			LL buf=1;
			for(int i=k; i<k+len; i++){
				int sav = f[len+i]*buf % mod;
				f[len+i] = (f[i]-sav+mod) % mod;
				f[i] = (f[i]+sav) % mod;
				buf = buf*w % mod;
			}
		}
	}
}
LL _g[Maxn<<2];
void times(LL *f, LL *gg, int len1, int len2, int limit){
	int n = 1;
	for(; n<len1+len2; n<<=1);
	LL *g = _g;
	for(int i=0; i<len2; i++) g[i] = gg[i];
	for(int i=len2; i<n; i++) g[i] = 0;
	invn = powM(n);
	for(int i=0; i<n; i++) r[i] = (r[i>>1]>>1)|((i&1)?n>>1:0);
	NTT(f,n,1); NTT(g,n,1);
	for(int i=0; i<n; ++i) f[i] = (f[i]*g[i]) % mod;
	NTT(f,n,-1);
	for(int i=0; i<limit; ++i) f[i] = f[i]*invn % mod;
	for(int i=limit; i<n; ++i) f[i] = 0;
}
int main(){
  invG = powM(G);
  scanf("%d%d", &n, &m); n++; m++;
  for(int i=0; i<n; i++) f[i] = read();
  for(int i=0; i<m; i++) g[i] = read();
  times(f, g, n, m, n+m-1);
  print(f, n+m-1);
  return 0;
}

多项式的逆(除法) (模数为NTT数)

公式 $R(X)=2R^(X)-R^(X)^2F(X)(mod\ x^n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
LL _r[Maxn<<2], _rr[Maxn<<2];
void invp(LL *f, int len){ //f的逆,依然放在f中
	LL *r=_r, *rr=_rr;
	int n=1; for(; n<len; n<<=1);
	for(int i=0; i<n; i++) rr[i] = r[i] = 0;
	rr[0] = powM(f[0]);
	for(int len=2; len<=n; len<<=1){
		for(int i=0; i<len; i++) r[i] = rr[i]*2 % mod;
		times(rr, rr, len/2, len/2, len); times(rr, f, len, len, len);
		for(int i=0; i<len; i++) rr[i] = (r[i]-rr[i]+mod) % mod;
	}
	for(int i=0; i<len; i++) f[i] = rr[i];
}
// print(f, m);

任意模数的多项式求逆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<complex>
using namespace std;
const double Pi=acos(-1);
const int N=800100;
const int M=30000;
const int mod=1e9+7;
int n,F[N],G[N],l,r[N],tt,P1[N],P2[N];
complex<double> A1[N],A2[N],B1[N],B2[N],w[N],U[N];
int ksm(int x,int k)
{
    int s=1;for(;k;k>>=1,x=1ll*x*x%mod)
    if(k&1) s=1ll*s*x%mod; return s;            
}
void FFT(complex<double> *P,int op)
{
    for(int i=0;i<l;i++) if(i<r[i]) swap(P[i],P[r[i]]);
    for(int i=1;i<l;i<<=1)
        for(int j=0,p=i<<1;j<l;j+=p)
            for(int k=0;k<i;k++)
            {
                complex<double> W=w[l/i*k];W.imag()*=op;
                complex<double> X=P[j+k],Y=P[j+k+i]*W;
                P[j+k]=X+Y;P[j+k+i]=X-Y;
            }
}
void Work(complex<double> *A,complex<double> *B,int b)
{
    for(int i=0;i<l;i++) U[i]=A[i]*B[i]; FFT(U,-1);
    for(int i=0;i<l;i++) (P1[i]+=((long long)(U[i].real()/l+0.5)%mod*b%mod+mod)%mod)%=mod;
}
void MTT()
{
    for(int i=0;i<l;i++) A1[i].real()=P1[i]/M,B1[i].real()=P1[i]%M,A1[i].imag()=B1[i].imag()=0;
    for(int i=0;i<l;i++) A2[i].real()=P2[i]/M,B2[i].real()=P2[i]%M,A2[i].imag()=B2[i].imag()=0;
    for(int i=0;i<l;i++) P1[i]=0; FFT(A1,1);FFT(A2,1);FFT(B1,1);FFT(B2,1);
    Work(A1,A2,M*M);Work(A1,B2,M);Work(B1,A2,M);Work(B1,B2,1);
}
void GetInv(int *f,int *g,int n)
{
    if(n==1) {g[0]=ksm(f[0],mod-2);return;}
    GetInv(f,g,n>>1);
    for(tt=0,l=1;l<n*2;l<<=1) tt++;tt--;
    for(int i=0;i<l;i++) r[i]=(r[i>>1]>>1)|((i&1)<<tt),P1[i]=P2[i]=0;
    for(int i=0;i<l;i++) w[i].real()=cos(Pi/l*i),w[i].imag()=sin(Pi/l*i);
    for(int i=0;i<n;i++) P1[i]=f[i],P2[i]=g[i]; MTT(); MTT();
    for(int i=0;i<n;i++) g[i]=((2ll*g[i]%mod-P1[i])%mod+mod)%mod;
}
int main()
{
    scanf("%d",&n); for(int i=0;i<n;i++) scanf("%d",&F[i]);
    int m; for(m=1;m<n;m<<=1); GetInv(F,G,m);
    for(int i=0;i<n;i++) printf("%d ",G[i]);
}

多项式的牛顿迭代和微积分

多项式开方

公式: $B(x)=\frac{A(x)+B^∗(x)^2}{2B^∗(x)}$

1
2
3
4
5
6
7
8
9
10
11
void sprt_p(LL *a, int m)
	b[0] = 1;
    for(int len=2; len<=n; len<<=1){
        for(int i=0; i<len/2; i++) bb[i] = (2*b[i]) % mod;
        for(int i=len/2; i<len; i++) bb[i] = 0; //注意清空,不要用memset,不要吝啬for循环 
        invp(bb, len); times(b, b, len, len);
        for(int i=0; i<len; i++) b[i] = (a[i]+b[i]) % mod;
        times(b, bb, len, len);
    }
}
// print(b, m);

多项式带余除法

参考博客

[洛谷日报 傅里叶变换(FFT)学习笔记](https://www.luogu.org/blog/command-block/fft-xue-xi-bi-ji)
[洛谷日报 NTT与多项式全家桶](https://www.luogu.org/blog/command-block/ntt-yu-duo-xiang-shi-quan-jia-tong)

从多项式乘法到快速傅里叶变换

  1. 原根表:http://blog.miskcoo.com/2014/07/fft-prime-table 

# 使用单$作为行内数学公式分界符