Post

男人八题之单调队列优化多重背包

多重背包的朴素解dp[i][j],借鉴01背包的优化思路不难想到一个将其优化成一维的解法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <iostream>
#include <algorithm>

//多重背包问题,同一个重量和value的物品有n件
//所以朴素的解法和完全背包是很类似的

using namespace std;

const int N = 110;


int dp[N];
int v[N],w[N],s[N];
int n,m;

int main(){
    cin >> n >> m;

    for(int i = 0;i < n;i++){
        int v,w,s;
        cin >> v >> w >> s;
        for(int j = m;j >= 0 ;j--){
            for(int k = 1;k <= s && k * v <= j;k++)
                dp[j] = max(dp[j],dp[j - k * v] + w * k);
        }
    }
    cout << dp[m] << endl;
}

但是当数据范围进一步扩大的时候,就需要思考一个更加狠的优化手段了。

不难想到,dp[m] = max(dp[m], dp[m-v] + w, dp[m-2\*v] + 2\*w, dp[m-3\*v] + 3\*w, ...)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
dp[0], dp[v],   dp[2*v],   dp[3*v],   ... , dp[k*v]
dp[1], dp[v+1], dp[2*v+1], dp[3*v+1], ... , dp[k*v+1]
dp[2], dp[v+2], dp[2*v+2], dp[3*v+2], ... , dp[k*v+2]
...
dp[j], dp[v+j], dp[2*v+j], dp[3*v+j], ... , dp[k*v+j]
显而易见,m 一定等于 k*v + j,其中  0 <= j < v
所以,我们可以把 dp 数组分成 j 个类,每一类中的值,都是在同类之间转换得到的
也就是说,dp[k*v+j] 只依赖于 { dp[j], dp[v+j], dp[2*v+j], dp[3*v+j], ... , dp[k*v+j] }
所以,我们可以把 dp 数组分成 j 个类,每一类中的值,都是在同类之间转换得到的
也就是说,dp[k*v+j] 只依赖于 { dp[j], dp[v+j], dp[2*v+j], dp[3*v+j], ... , dp[k*v+j] }
所以,我们可以得到
dp[j]    =     dp[j]
dp[j+v]  = max(dp[j] +  w,  dp[j+v])
dp[j+2v] = max(dp[j] + 2w,  dp[j+v] +  w, dp[j+2v])
dp[j+3v] = max(dp[j] + 3w,  dp[j+v] + 2w, dp[j+2v] + w, dp[j+3v])
...
但是,这个队列中前面的数,每次都会增加一个 w ,所以我们需要做一些转换
dp[j]    =     dp[j]
dp[j+v]  = max(dp[j], dp[j+v] - w) + w
dp[j+2v] = max(dp[j], dp[j+v] - w, dp[j+2v] - 2w) + 2w
dp[j+3v] = max(dp[j], dp[j+v] - w, dp[j+2v] - 2w, dp[j+3v] - 3w) + 3w
...
这样,每次入队的值是 dp[j+k*v] - k*w
最后dp[k]需要求的值就是dp[i - 1][q[hh]] + (k - q[hh]) / v * w

转化成实现就是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <iostream>
#include <algorithm>
#include <cstring>

using namespace std;

const int N = 2e4 + 10;

int q[N],g[N],f[N];
int n,m;

int main(){
    cin >> n >> m;
    for(int i = 0;i < n;i++){
        int v,w,s;
        cin >> v >> w >> s;
        memcpy(g,f,sizeof f);
        for(int j = 0;j < v;j++){
            int hh = 0,tt= -1;
            for(int k = j;k <= m;k += v){
                if(hh <= tt && q[hh] < k - s * v) hh++;
                while(hh <= tt && g[q[tt]] - (q[tt] - j) / v * w <= g[k] - (k - j) / v * w ) --tt;
                q[++tt] = k;
                f[k] = g[q[hh]] + (k - q[hh]) / v * w;
            }
        }
    }
    cout << f[m] << endl;
}
This post is licensed under CC BY 4.0 by the author.

Trending Tags