线段树 (SegmentTree) 是一种基于分治思想的二叉树结构, 用于区间上进行信息统计. 与按照二进制位进行划分的树状数组相比, 线段树是一种更加通用的结构.
性质
线段树每个节点都代表一个区间.
线段树具有唯一的根节点, 代表的区间是整个统计范围.
线段树的每个叶节点都代表一个长度为 1 的元区间.
对于每个内部节点[l, r], 它的左子节点是[l, mid], 右子节点是[mid+1, r], 其中 mid = (l + r)>> 1.
建树
每个叶节点 t[i, i]维护 a[i]的值, 从而使信息从下向上传递信息.
单点更改
从根节点出发, 找到区间 [x, x] 的叶节点, 然后从下向上更新.
区间查询
找到完全覆盖当前节点的区间, 立即回溯.
若左子节点有重合的部分, 则访问左子节点.
若右子节点有重合的部分, 则访问右子节点.
延迟标记
区间修改的指令中, 某个区间的结点改变, 整棵子树中的所有节点存储的信息都会发生变化, 修改的时间复杂度将会增加到 O(N).
我们发现, 若我们将区间的整棵子树进行更新, 却始终没有进行查询, 那么整棵子树的更新都是徒劳的. 所以, 我们可以在修改指令时找到完全覆盖的区间后立即返回, 只是在其中增加一个标记, 表示此节点已经更改, 但子树都未更新.
此后的更改和查询命令中, 在向下访问之前, 我们先将未更新的结点向下传递, 然后清除当前结点的标记. 这样一来, 时间复杂度仍可维持在 O(logN).
模板:
题目描述
如题, 已知一个数列, 你需要进行下面两种操作:
1. 将某区间每一个数加上 x
2. 求出某区间每一个数的和
输入输出格式
输入格式:
第一行包含两个整数 N,M, 分别表示该数列数字的个数和操作的总个数.
第二行包含 N 个用空格分隔的整数, 其中第 i 个数字表示数列第 i 项的初始值.
接下来 M 行每行包含 3 或 4 个整数, 表示一个操作, 具体如下:
操作 1: 格式: 1 x y k 含义: 将区间 [x,y] 内每个数加上 k
操作 2: 格式: 2 x y 含义: 输出区间 [x,y] 内每个数的和
输出格式:
输出包含若干行整数, 即为所有操作 2 的结果.
- #include<iostream>
- using namespace std;
- const int SIZE = 100005;
- struct SegmentTree{
- int l, r;
- long long add, sum;
- }t[SIZE*4];
- int a[SIZE], n, m;
- void build(int p, int l, int r){
- t[p].l = l, t[p].r = r;
- if(l == r){
- t[p].sum = a[l];
- return;
- }
- int mid = (l + r)>> 1;
- build(p*2, l, mid);
- build(p*2+1, mid+1, r);
- t[p].sum = t[p*2].sum + t[p*2+1].sum;
- }
- void spread(int p){
- if(t[p].add){ // 结点 p 有标记
- t[p*2].sum += t[p].add * (t[p*2].r-t[p*2].l+1);
- t[p*2+1].sum += t[p].add * (t[p*2+1].r-t[p*2+1].l+1);
- t[p*2].add += t[p].add; // 延迟标记
- t[p*2+1].add += t[p].add;
- t[p].add = 0;
- }
- }
- void change(int p, int l, int r, int d){
- if(l <= t[p].l && r>= t[p].r){
- t[p].sum += (long long)d * (t[p].r - t[p].l + 1);
- t[p].add += d;
- return;
- }
- spread(p);
- int mid = (t[p].l + t[p].r)>> 1;
- if(l <= mid) change(p*2, l, r, d);
- if(r> mid) change(p*2+1, l, r, d);
- t[p].sum = t[p*2].sum + t[p*2+1].sum;
- }
- long long ask(int p, int l, int r){
- if(l <= t[p].l && r>= t[p].r) return t[p].sum;
- spread(p);
- int mid = (t[p].l + t[p].r)>> 1;
- long long val = 0;
- if(l <= mid) val += ask(p*2, l, r);
- if(r> mid) val += ask(p*2+1, l, r);
- return val;
- }
- int main(){
- cin>> n>> m;
- for(int i=1; i<=n; i++) cin>> a[i];
- build(1, 1, n);
- while(m--){
- int k, l, r, h;
- cin>> k>> l>> r;
- switch(k){
- case 1:
- cin>> h;
- change(1, l, r, h);
- break;
- case 2:
- cout << ask(1, l, r) <<"\n";
- break;
- }
- }
- return 0;
- }
来源: http://www.bubuko.com/infodetail-3115807.html