YZOJ P3752 序列求差问题
时间限制:2000MS 内存限制:131072KB
出题人:Night        难度:\(6.0\)
- 
题目描述
 
有一个序列 \(x_1,x_2,\cdots,x_n\) 。
求有多少个从 \(1,2,\cdots,n\) 中取三个元素的排列 \((a,b,c)\) 满足 \(x_a=x_b-x_c\) 。
由于是排列,所以 \((a,b,c)\) 与 \((c,b,a)\) 视为两组解。
- 
输入格式
 
第一行一个整数 \(n\) 表示序列长度。
第二行为 \(n\) 个整数表示序列里的 \(n\) 个数。
- 
输出格式
 
一行一个正整数,表示答案。
- 
样例输入
 
| 
					 1 2  | 
						10 1 6 2 9 5 9 2 5 0 5  | 
					
- 
样例输出
 
| 
					 1  | 
						26  | 
					
- 
数据规模与约定
 
对于 \(20\%\) 的数据,\(1 \leq n \leq 500\);
对于 \(45\%\) 的数据,\(1 \leq n \leq 5000\);
对于 \(100\%\) 的数据,\(1 \leq n \leq 1000000\),\(0 \leq \left|x_i\right| \leq 100000\) 。
首先这个东西和 \(x_a+x_b=x_c\) 是等价的。
对于 \(n \leq 500\) 的数据,显然可以 \(O(n^3)\) 枚举 \((a,b,c)\) 暴力判断。
对于 \(n \leq 5000\) 的数据,可以记桶 \(cnt_i\) 表示 \(i=x_j\) 的不同 \(j\) 的个数,只要枚举 \((a,b)\) ,答案加上 \(cnt_{x_a+x_b}\) 即可,\(O(n^2)\)。
然后对于 \(100\%\) 的数据,出题人就发现 \(x_a+x_b=x_c\) 这个东西很像多项式乘法,因为可以把 \(cnt_{x_a}\) 和 \(cnt_{x_b}\) 贡献到 \(cnt_{x_c}\) 上。
所以就把 \(cnt\) 作为一个多项式,与它自己相乘,得到的就是答案了。
多项式乘法用 FFT 优化至 \(O(nlogn)\) 。
细节:
1,因为不能取两个相同的,所以 \(ans_{x_i+x_i}\) 要减一。
2,还有要特判一下 \(0\) 的情况,所以答案要减去 \(2 \times m \times (n-1)\) (其中 \(m\) 为 \(x\) 中 \(0\) 的个数)。
3,有负数,所以要整体偏移,\(x_a+diff+x_b+diff=x_c+2 \times diff\) ,注意多项式乘法的答案意义也有所变化。
4,因为答案可能超出 \(int\) ,所以不能使用 FWT/NTT 求解,而且 
			double 会被卡精度,所以要换成
			long double 继续,正确的做法是考虑到答案肯定不超过 \(A^3_{1000000}=999997000002000000\),可以使用 \(998244353\) 和 \(1004535809\) 两个模数分别 FNT 求一遍,然后再 CRT(中国剩余定理) 合并计算结果。
答案为 \(\sum ans_{x_i+2 \times diff}-2 \times m \times (n-1)\) 。
| 
					 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118  | 
						#include <cstdio> #include <cstdlib> #include <cstring> #include <climits> #include <cmath> #define DIFF 100001 #define _max(_a_,_b_) ((_a_)>(_b_)?(_a_):(_b_)) inline int getnum() {     register char c=0;     register bool neg=false;     while(!(c>='0' && c<='9'))         c=getchar(),neg|=(c=='-');     register int a=0;     while(c>='0' && c<='9')         a=a*10+c-'0',c=getchar();     return (neg?-1:1)*a; } struct _cm {     // a+bi     long double a,b;     _cm(long double na=0,long double nb=0){a=na,b=nb;}     _cm operator + (const _cm&o)const     {         return (_cm){a+o.a,b+o.b};     }     _cm operator - (const _cm&o)const     {         return (_cm){a-o.a,b-o.b};     }     _cm operator * (const _cm&o)const     {         return (_cm){a*o.a-b*o.b,a*o.b+b*o.a};     } }c[805050]; int flip[805050]; inline void FFT(register int len,_cm c[],bool inverse=false) {     _cm t;     for(register int i=0;i<len;i++)         if(i<flip[i])             t=c[flip[i]],c[flip[i]]=c[i],c[i]=t;     for(register int step=2;step<=len;step<<=1)     {         _cm wn=(_cm){std::cos((long double)2*M_PI/step*(inverse ? -1 : 1)), \         std::sin((long double)2*M_PI/step*(inverse ? -1 : 1))};         for(register int k=0;k<len;k+=step)         {             _cm w=(_cm){1};             for(register int i=k;i < k+(step>>1);i++)             {                 register int j=i+(step>>1);                 t=c[j]*w;                 c[j]=c[i]-t;                 c[i]=c[i]+t;                 w=w*wn;             }         }     }     if(inverse)         for(register int i=0;i<len;i++)             c[i].a/=len; } int a[1050505]; long long cnt[805050]; int main() {     register int N=getnum(),mxa=0,cnt0=0;     for(register int i=1;i<=N;i++)     {         a[i]=getnum()+DIFF;         mxa=_max(a[i],mxa),cnt0+=(a[i]==DIFF);         cnt[a[i]]++;     }     mxa<<=1;     register int len=1,dig=0;     while(len<mxa)         len<<=1,dig++;     for(register int i=0;i<len;i++)     {         c[i]=(_cm){(long double)((i<<1)<=mxa ? cnt[i] : 0)};         flip[i]=(flip[i>>1]>>1)|((i&1)<<(dig-1));     }     //printf("--------------- A\n");     //for(register int i=0;i<len;i++)     //    printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b);     FFT(len,c);     //printf("--------------- B\n");     //for(register int i=0;i<len;i++)     //    printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b);     for(register int i=0;i<len;i++)         c[i]=c[i]*c[i];     FFT(len,c,true);     //printf("--------------- C\n");     //for(register int i=0;i<len;i++)     //    printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b);     for(register int i=1;i<=mxa;i++)         cnt[i]=(long long)(c[i].a+0.49999);     long long ans=-(long long)cnt0*(N-1)<<1;     for(register int i=1;i<=N;i++)         cnt[a[i]<<1]--;     for(register int i=1;i<=N;i++)         ans+=cnt[a[i]+DIFF];     printf("%lld\n",ans);     return 0; }  |