昨天考试被教育了一波. 为了学习一下 \(T3\)的科技, 我就找到了这个远古时期的 \(cf\)题 (虽然最后 \(T3\) 还是不会写吧 \(QAQ\))
顾名思义, 这个题目其实可以建成一个费用流的模型. 我们用流量来限制区间个数, 用费用强迫它每次每次选择最大的区间就可以啦. 但是因为询问很多, 复杂度似乎不行, 于是就有了这种神奇的科技 -- 线段树模拟费用流.
在原先的费用流模型里, 我们有正反两种边, 而反向边的意义就在于, 在每一次增广的时候可以反悔以前的操作, 把局部最优向更大范围的局部更优优化.
参考反向边的原理, 我们可以想象出来: 如果对这个区间, 我们每次都取用最大子区间, 并在取用这个最大子区间以后将其价值变为负数, 不就可以模拟费用流的行为了嘛? 这样做的复杂度是 \(O(NMlogN)\)的, 可以解决更大数据范围的问题.
算法很好理解, 关键是千万不要把代码写挂 \(QwQ\), 真的不是很好调啊 \(TwT\)
- #include<bits/stdc++.h>
- using namespace std;
- struct dat {
- int s; //sum of sequence
- int lmx, lmxp; //left -> max_val && it's pos
- int lmn, lmnp; //left -> min_val && it's pos
- int rmx, rmxp; //righ -> max_val && it's pos
- int rmn, rmnp; //righ -> min_val && it's pos
- int smx, smxl, smxr; //sub -> max_val && it's pos (l, r)
- int smn, smnl, smnr; //sub -> min_val && it's pos (l, r)
- dat (int pos = 0, int val = 0){
- lmxp = lmnp = rmxp = rmnp = smxl = smxr = smnl = smnr = pos;
- s = lmx = lmn = rmx = rmn = smx = smn = val;
- // 对单个的点进行数据更新
- }
- }T[400010];
- dat operator + (dat l, dat r){
- dat u;
- u.s = l.s+r.s; // 先更新关于和的数据
- if (l.lmx> l.s + r.lmx) { //max_left's pos 是否越过 mid
- u.lmx = l.lmx;
- u.lmxp = l.lmxp;
- } else {
- u.lmx = l.s + r.lmx;
- u.lmxp = r.lmxp;
- }
- if (r.rmx> r.s + l.rmx) { //max_righ's pos 是否越过 mid
- u.rmx = r.rmx;
- u.rmxp = r.rmxp;
- } else {
- u.rmx = r.s + l.rmx;
- u.rmxp = l.rmxp;
- }
- if (l.lmn <l.s + r.lmn) { //min_left's pos 是否越过 mid
- u.lmn = l.lmn;
- u.lmnp = l.lmnp;
- } else {
- u.lmn = l.s + r.lmn;
- u.lmnp = r.lmnp;
- }
- if (r.rmn < r.s + l.rmn) { //min_righ's pos 是否越过 mid
- u.rmn = r.rmn;
- u.rmnp = r.rmnp;
- } else {
- u.rmn = r.s + l.rmn;
- u.rmnp = l.rmnp;
- }
- if (l.smx> r.smx) { // 最大子段 in left / righ
- u.smx = l.smx;
- u.smxl = l.smxl;
- u.smxr = l.smxr;
- } else {
- u.smx = r.smx;
- u.smxl = r.smxl;
- u.smxr = r.smxr;
- }
- if (l.rmx + r.lmx> u.smx){ // 最大子段是否越过 mid
- u.smx = l.rmx + r.lmx;
- u.smxl = l.rmxp;
- u.smxr = r.lmxp;
- }
- if (l.smn <r.smn) { // 最小子段 in left / righ
- u.smn = l.smn;
- u.smnl = l.smnl;
- u.smnr = l.smnr;
- } else {
- u.smn = r.smn;
- u.smnl = r.smnl;
- u.smnr = r.smnr;
- }
- if (l.rmn + r.lmn < u.smn) { // 最小子段是否跨过 mid
- u.smn = l.rmn + r.lmn;
- u.smnl = l.rmnp;
- u.smnr = r.lmnp;
- }
- return u;
- }
- #define ls (x << 1)
- #define rs (x << 1 | 1)
- void pushup (int x) {
- T[x] = T[ls] + T[rs];
- }
- int a[100010];
- void build (int l, int r, int x) {
- if (l == r) {
- T[x] = dat (l, a[l]);
- return;
- }
- int mid = (l + r)>> 1;
- build (l, mid, ls);
- build (mid + 1, r, rs);
- pushup (x);
- }
- int f[400010];
- void rev (int x) {
- dat &u = T[x];
- // max 变成 min
- swap (u.lmx, u.lmn);
- swap (u.lmxp, u.lmnp);
- swap (u.rmx, u.rmn);
- swap (u.rmxp, u.rmnp);
- swap (u.smx, u.smn);
- swap (u.smxl, u.smnl);
- swap (u.smxr, u.smnr);
- f[x] ^= 1;
- u.lmx *= -1;
- u.lmn *= -1;
- u.rmx *= -1;
- u.rmn *= -1;
- u.smx *= -1;
- u.smn *= -1;
- u.s *= -1;
- }
- void pushdown (int x) {
- if(f[x]) {
- rev (ls);
- rev (rs);
- f[x] = 0;
- }
- }
- void modify (int p, int v, int l, int r, int x) {
- if (l == r) {
- T[x] = dat (l, v);
- return;
- }
- pushdown (x);
- int mid = (l + r)>> 1;
- if (p <= mid) {
- modify (p, v, l, mid, ls);
- } else {
- modify (p, v, mid + 1, r, rs);
- }
- pushup (x);
- }
- void reverse (int L, int R, int l, int r, int x) {
- // 其实就是取用啦
- if (L <= l && r <= R) return rev (x);
- pushdown (x);
- int mid = (l + r)>> 1;
- if (L <= mid) reverse (L, R, l, mid, ls);
- if (mid <R) reverse (L, R, mid + 1, r, rs);
- pushup (x);
- }
- dat query (int L, int R, int l, int r, int x) {
- // 求 [l, r] 区间内的最大值嘛
- if (L <= l && r <= R) return T[x];
- pushdown (x);
- int mid = (l + r)>> 1;
- if (R <= mid) return query (L, R, l, mid, ls); // 如果区间全在左边
- if (mid <L) return query (L, R, mid + 1, r, rs); // 如果区间全在右边
- return query (L, R, l, mid, ls) + query (L, R, mid + 1, r, rs); // 跨 mid 了 QwQ
- }
- int L[30], R[30], top;
- int n, m, x, y, k, opt;
- int main () {
- cin>> n;
- for (int i = 1; i <= n; ++i) {
- cin>> a[i];
- }
- build (1, n, 1);
- cin>> m;
- for (int i = 1; i <= m; ++i) {
- cin>> opt>> x>> y;
- if (opt == 0) {
- modify (x, y, 1, n, 1); // 把点 x 的值改为 y
- } else {
- cin>> k; // 在 [x, y] 之间取 k 段的最大值
- int ans = 0;
- for (int j = 1; j <= k; ++j) {
- dat t = query (x, y, 1, n, 1);
- if (t.smx <= 0) break;
- // 选至多 k 段, 可以少选 !
- ans += t.smx;
- L[++top] = t.smxl, R[top] = t.smxr;
- reverse (L[top], R[top], 1, n, 1);
- }
- while (top) {
- reverse (L[top], R[top], 1, n, 1);
- top = top - 1;
- }
- cout << ans << endl;
- }
- }
- }
来源: http://www.bubuko.com/infodetail-2958276.html