之前向学习了一个 $\text{FFT}$ 的优化, 但是像我这么弱的人每次打 $\text{FFT}$ 板子的时候都会忘记这个东西, 在这里记一下.
? 我们知道普通的 $\text{FFT}$ 会用到原根 $\omega_n^0,\omega_n^1\cdots\omega_n^{n-1}$ 然后这些东西会在枚举步长的时候通过 $\omega_n = e^{\frac{2\pi}{n}}$ 和 $e^{\theta i} = \cos \theta + i\sin \theta$ 这两个公式一次一次算出来.
? 然而我们知道, 调用三角函数是非常慢的, 每次计算的时候, 即使你是手写的 $\text{complex}$ 也会非常慢, 这就使得这种 $\text{FFT}$ 的常数巨大无比.
? 所以我们就预处理一下每次需要用到的 $\omega$ , 把每一种步长需要用到的 $\omega$ 扔到同一个数组 $W$ 里, 有每种步长的 $\omega$ 连续. 而因为 $\sum_{i=0}^{n} 2^i = 2^{i + 1} - 1$ , 所以每次需要访问步长为 $s$ 的 $\omega$ 时候只要访问 $W[s]$ 就可以了, 将一个指针指向他, 而后面的只要把指针一步一步往后移即可.
? 这是 $\text{DFT}$ 的时候用的, 但是我们知道 $\text{IDFT}$ 的时候用的 $\omega$ 和 $\text{DFT}$ 的时候是不一样的.
? 然而我们不需要重新处理 $\text{IDFT}$ 用的 $\omega$ , 只需要把需要 $\text{FFT}$ 的 $A$ 从 $1$ 到 $n - 1$ 的值 $\text{reverse}$ 一下就行了. 原理是本来 $\text{IDFT}$ 的时候需要把 $\omega$ 翻过来, 但是那个有点麻烦, 于是我们就把 $A$ 给翻过来就行了. 由于 $\text{FFT}$ 可以被理解为一个特殊的矩阵乘法, 所以你顺着搞下来和反着搞回去最后的结果是一样的, 所以它是对的.
? 然后下面贴了一道水题的代码来帮助理解:
例:? 求有多少个从 $1,2,\cdots,n$ 中取三个元素的排列 $(a,b,c)$ 满足 $x_a=x_b-x_c$.? 由于是排列, 所以 $(a,b,c)$ 与 $(c,b,a)$ 视为两组解.
- #include <algorithm>
- #include <cstdio>
- #include <cmath>
- #include <cstdlib>
- #include <cstring>
- #include <ctime>
- #include <iostream>
- #include <queue>
- #include <set>
- #include <stack>
- #define R register
- #define ll long long
- #define db double
- #define ld long double
- #define sqr(_x) (_x) * (_x)
- #define Cmax(_a, _b) ((_a) <(_b) ? (_a) = (_b), 1 : 0)
- #define Cmin(_a, _b) ((_a)> (_b) ? (_a) = (_b), 1 : 0)
- #define Max(_a, _b) ((_a)> (_b) ? (_a) : (_b))
- #define Min(_a, _b) ((_a) <(_b) ? (_a) : (_b))
- #define Abs(_x) (_x < 0 ? (-(_x)) : (_x))
- using namespace std;
- namespace Dntcry
- {
- inline int read()
- {
- R int a = 0, b = 1; R char c = getchar();
- for(; c < '0' || c> '9'; c = getchar()) (c == '-') ? b = -1 : 0;
- for(; c>= '0' && c <= '9'; c = getchar()) a = (a <<1) + (a << 3) + c - '0';
- return a * b;
- }
- inline ll lread()
- {
- R ll a = 0, b = 1; R char c = getchar();
- for(; c < '0' || c> '9'; c = getchar()) (c == '-') ? b = -1 : 0;
- for(; c>= '0' && c <= '9'; c = getchar()) a = (a <<1) + (a << 3) + c - '0';
- return a * b;
- }
- const int Maxn = 1000010, Maxl = 600010, lim = 100000;
- const ld pi = acos(-1);
- struct Complex
- {
- ld real, imag;
- Complex operator + (const Complex &b) const
- {
- return (Complex) {real + b.real, imag + b.imag};
- }
- Complex operator - (const Complex &b) const
- {
- return (Complex) {real - b.real, imag - b.imag};
- }
- Complex operator * (const Complex &b) const
- {
- return (Complex) {real * b.real - imag * b.imag, b.real * imag + real * b.imag};
- }
- }C[Maxl], A[Maxl], w[Maxl], wl;
- int n, m, x[Maxn], Cnt[Maxl], len, bit, rev[Maxl], zero;
- ll Ans[Maxn], Sum;
- void Get_Rev(R int bit)
- {
- for(R int i = 0; i < len; i++)
- rev[i] = (rev[i>> 1]>> 1) | ((i & 1) <<bit - 1);
- return ;
- }
- void FFT(R Complex *K, R ld DFT)
- {
- for(R int i = 0; i < len; i++) if(i < rev[i]) swap(K[i], K[rev[i]])
- R Complex *W;
- for(R int i = 2; i <= len; i <<= 1)
- {
- for(R int j = 0, step = i>> 1; j <len; j += i)
- {
- W = w + step;
- for(R int k = j; k < j + step; W++, k++)
- {
- R Complex G = K[k], H = *W * K[k + step];
- K[k] = G + H;
- K[k + step] = G - H;
- }
- }
- }
- if(DFT == -1.0)
- for(R int i = 0; i < len; i++)
- K[i].real /= 1.0 * len, K[i].imag /= 1.0 * len;
- return ;
- }
- int Main()
- {
- n = read();
- for(R int i = 1; i <= n; i++)
- {
- x[i] = read(); if(!x[i]) zero++;
- x[i] += lim, m = Max(m, x[i]);
- Cnt[x[i]]++;
- } m++;
- for(bit = 0, len = 1; (1 << bit) < (m << 1); bit++) len <<= 1;
- R int tmp = len>> 1;
- w[tmp] = (Complex) {1.0, 0.0};
- wl = w[++tmp] = (Complex) {cos(2.0 * pi / len), sin(2.0 * pi / len)};
- for(tmp++; tmp <len; tmp++) w[tmp] = w[tmp - 1] * wl;
- for(R int i = (len>> 1) - 1; i; i--) w[i] = w[i << 1];
- Get_Rev(bit);
- for(R int i = 0; i < m; i++) A[i] = (Complex) {1.0 * Cnt[i], 0.0};
- FFT(A, 1.0);
- C[0] = A[0] * A[0];
- for(R int i = 1; i < len; i++) C[i] = A[len - i] * A[len - i];
- FFT(C, -1.0);
- for(R int i = 0; i < len; i++) Ans[i] = (ll)(C[i].real + 0.5);
- for(R int i = 1; i <= n; i++) Ans[x[i] << 1]--;
- for(R int i = 1; i <= n; i++) Sum += Ans[x[i] + lim];
- Sum -= 2ll * zero * (n - 1);
- printf("%lld\n", Sum);
- return 0;
- }
- }
- int main()
- {
- return Dntcry :: Main();
- }
来源: http://www.bubuko.com/infodetail-2970484.html