YZOJ P3216 行商
时间限制:1000MS 内存限制:262144KB
难度:\(4.0\)
-
题目描述
有 \(n\) 个货物,每个货物都有各自的重量 \(w_i\) 和价值 \(c_i\),但是载重量仅为 \(m\) 。
挑选出一些货物,总重量不超过 \(m\),使价值之和最大。
-
输入格式
第一行,两个整数 \(n\),\(m\);
接下来 \(n\) 行,每行两个整数 \(w_i\),\(c_i\) 。
-
输出格式
一个整数 \(ans\) 。
-
样例输入
1 2 3 4 5 |
4 3 3 10 2 7 2 8 1 1 |
-
样例输出
1 |
10 |
-
数据规模与约定
\(1 \leq n \leq 10^6\),\(1 \leq m \leq 4^{31}\),\(1 \leq w_i \leq 3\),\(1 \leq c_i \leq 10^9\) 。
首先 \(w\) 这么小,肯定是把 \(3\) 种情况分开考虑。
将货物按照 \(w\) 分为 \(3\) 类,分别按从大到小进行排序,然后顺便记个前缀和 \(sum\) 。
发现如果先不考虑 \(w=3\) 的情况,那么有:选 \(i\) 件货物时,有 \(f_i = sum2_i + sum1_{m – 2i}\) 。
可以发现 \(sum2\) 是一个单调上升并且二阶导数 \(<0\) 的整点函数,\(sum1_{m – 2i}\) 是一个单调下降并且二阶导数 \(>0\) 的整点函数。
经过简单的数学推理,可以发现它们相加,结果会是一个单峰函数。
这样就枚举 \(w=3\) 取的数量,然后三分,时间复杂度 \(O(nlogn)\) 。
(p.s. \(m \leq 4^{31}\) 看似很大,但是当 \(m \geq 3n\) 所有的货物都可以装的下,所以并没有什么用)
(貌似还有一种 DP 的做法)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
#include <cstdio> #include <cstdlib> #include <cstring> #include <climits> #include <algorithm> #include <functional> #define _min(_a_,_b_) ((_a_)<(_b_)?(_a_):(_b_)) #define _max(_a_,_b_) ((_a_)>(_b_)?(_a_):(_b_)) #ifdef ONLINE_JUDGE char __B[1<<15],*__S=__B,*__T=__B; #define getchar() (__S==__T&&(__T=(__S=__B)+fread(__B,1,1<<15,stdin),__S==__T)?EOF:*__S++) #endif template<class T> inline T getnum() { register char c=0; while(!(c>='0' && c<='9')) c=getchar(); register T a=0; while(c>='0' && c<='9') { a*=10,a+=c-'0'; c=getchar(); } return a; } int lc1,lc2,lc3; long long nc1[1050505],nc2[1050505],nc3[1050505]; int main() { register int N=getnum<int>(); register long long M=getnum<long long>(),tc=0; register int tw=0; for(register int i=1;i<=N;i++) { register int w=getnum<int>(); if(w==1) nc1[++lc1]=getnum<int>(),tc+=nc1[lc1]; else if(w==2) nc2[++lc2]=getnum<int>(),tc+=nc2[lc2]; else nc3[++lc3]=getnum<int>(),tc+=nc3[lc3]; tw+=w; } if(tw<=M) { printf("%lld\n",tc); return 0; } std::sort(&nc1[1],&nc1[lc1+1],std::greater<int>()); std::sort(&nc2[1],&nc2[lc2+1],std::greater<int>()); std::sort(&nc3[1],&nc3[lc3+1],std::greater<int>()); for(register int i=1;i<=lc1;i++) nc1[i]+=nc1[i-1]; for(register int i=1;i<=lc2;i++) nc2[i]+=nc2[i-1]; for(register int i=1;i<=lc3;i++) nc3[i]+=nc3[i-1]; long long ans=0; for(register int s3=0;s3<=lc3;s3++) { register int sl=M-3*s3; if(sl<0) break; //printf("s3: %d left: %d\n",s3,sl); if((lc2<<1)+lc1 <= sl) ans=_max(ans,nc2[lc2]+nc1[lc1]+nc3[s3]); else { #define AnsPoint(_x) \ nc2[_x]+nc1[_min(sl-((_x)<<1),lc1)] register int l=0,r=_min(sl>>1,lc2); while(l+1<r) { //printf("divide %d %d\n",l,r); register int lmid=(l+r)>>1,rmid=(lmid+r)>>1; //printf("lmid %d(%d) rmid %d(%d)\n",lmid,sl-(lmid<<1),rmid,sl-(rmid<<1)); register long long ansl=AnsPoint(lmid),ansr=AnsPoint(rmid); //printf("ansl %lld ansr %lld\n",ansl,ansr); if(ansl>ansr) r=rmid; else l=lmid; } ans=_max(ans,_max(AnsPoint(l),AnsPoint(r))+nc3[s3]); } } printf("%lld\n",ans); return 0; } |