数据结构-线段树

数据结构-线段树

参考资料

暂无

线段树是所有 RMQ 中最常用的数据结构。

功能:区间修改区间查询。不止最值、求和。只要可递推的值都可以构造线段树。

如果区间大小为 (n),线段树有 (cnt) 个节点,那么 (2n-1le cnt<4n)

节点

对于每个节点 (x),和堆类似,父亲节点为 (x>>1)(即 (x/2) 下取整的位运算方法,位运算方便而且快),左儿子为 (x<<1)(即 (2x)),右儿子为 (x<<1|1)(即 (2x+1))。

同时每个节点对应一段区间,所以叫线段树。节点 (1) 对应的区间为 (1sim n)。设一个节点对应的区间为 (lsim r),那么它的左儿子对应的区间就是 (lsim mid),其中 (mid=(l+r)>>1),右儿子区间为 (mid+1sim r)。如果一个节点对应单点区间,就没有儿子。

同时每个节点对应一个值,即该区间的 RMQ 值。如果是求最值问题,就表示该区间最大值;如果是求和问题,就表示该区间的和。

操作(单点修改区间查询)

一个线段树是求和还是求最值或者求别的东西,取决于 (pushup(k)) 函数,其中 (k) 为节点编号,时间复杂度 (O(1))

void pushup(int k){v[k]=max(v[k<<1],v[k<<1|1]);}//求最大值

根据原序列构造初始的线段树用 (build()) 函数,单点节点上的值就为单点的值,递归从下到上构造,时间复杂度 (O(nlog n))

void build(int k=1,int l=1,int r=n){//表示外部应用默认k=1,l=1,r=n      if(l==r){v[k]=a[l];return;} //单点节点      build(k<<1,l,mid),build(k<<1|1,mid+1,r); //递归构造      pushup(k); //递推  }

先讲单点修改(加上 (y)),只需与 (build()) 函数类似的递归操作即可,如果到达单点节点,就修改,不走那些跟查询单点没关系的区间、别忘了修改完后也要递推,时间复杂度 (O(log n))

void fix(int x,int y,int k=1,int l=1,int r=n){      if(l==x&&r==x){v[k]+=y;return;} //单点修改      if(mid>=x) fix(x,y,k<<1,l,mid); //递归左儿子      else fix(x,y,k<<1|1,mid+1,r); //递归右儿子      pushup(k);//递推  }

区间查询,如果单前节点在查询区间内,就返回值。否则,递归左儿子右儿子,递推得区间查询值。时间复杂度 (O(log n)),因为只会走相关的 (log n) 个节点。

int fmax(int x,int y,int k=1,int l=1,int r=n){      if(x<=l&&r<=y) return v[k]; //在查找区间内,返回值      int res=0;      if(mid>=x) res=max(res,fmax(x,y,k<<1,l,mid));      if(mid<y) res=max(res,fmax(x,y,k<<1|1,mid+1,r));      return res;  }

总时间复杂度 (O(nlog n)) ,全代码:

#include <bits/stdc++.h>  using namespace std;  const int N=1e5+10;  int n,m,a[N];  namespace Sumtree{      #define mid ((l+r)>>1)      int v[N<<2];      void pushup(int k){v[k]=max(v[k<<1],v[k<<1|1]);}      void build(int k=1,int l=1,int r=n){          if(l==r){v[k]=a[l];return;}          build(k<<1,l,mid),build(k<<1|1,mid+1,r);          pushup(k);      }      void fix(int x,int y,int k=1,int l=1,int r=n){          if(l==x&&r==x){v[k]+=y;return;}          if(mid>=x) fix(x,y,k<<1,l,mid);          else fix(x,y,k<<1|1,mid+1,r);          pushup(k);      }      int fmax(int x,int y,int k=1,int l=1,int r=n){          if(x<=l&&r<=y) return v[k];          int res=0;          if(mid>=x) res=max(res,fmax(x,y,k<<1,l,mid));          if(mid<y) res=max(res,fmax(x,y,k<<1|1,mid+1,r));          return res;      }      #undef mid  }using namespace Sumtree;  int main(){      scanf("%d%d",&n,&m);      for(int i=1;i<=n;i++)          scanf("%d",a+i);      build();      for(int i=1,x,y,z;i<=m;i++){          scanf("%d%d%d",&x,&y,&z);          if(x==1) fix(y,z);          else printf("%dn",fmax(y,z));      }      return 0;  }

线段树如果只能单点修改区间查询,代码还这么长,就没人用他了。所以可想而知,线段树还可以区间修改,区间查询。

操作(区间修改区间查询)

先看如何区间修改,初学者会这么写:

void fix(int x,int y,int z,int k=1,int l=1,int r=n){      if(x<=l&&r<=y){v[k]+=z;return;}      if(mid>=x) fix(x,y,z,k<<1,l,mid);      if(mid<y) fix(x,y,z,k<<1|1,mid+1,r);      pushup(k);  }

问题是这样的话对于每个区间属于 ([x,y]) 的节点,它的子节点就会没被修改。

初学者还可能这么写:

void fix(int x,int y,int z,int k=1,int l=1,int r=n){      if(l==r){v[k]+=z;return;}      if(mid>=x) fix(x,y,z,k<<1,l,mid);      if(mid<y) fix(x,y,z,k<<1|1,mid+1,r);      pushup(k);  }

问题在于时间复杂度为 (O(n))

那么区间修改的主角就要出场了——懒标记((texttt{lazytag}))。对于每个节点,多加一个值,(mk[])(mk[x]) 表示 (x) 节点的标记。每次修改在前面那个初学者的代码上加上终止区间懒标记。

void fix(int x,int y,int z,int k=1,int l=1,int r=n){      if(x<=l&&r<=y){v[k]+=z,mk[k]+=z;return;}      pushdown(k);      if(mid>=x) fix(x,y,z,k<<1,l,mid);      if(mid<y) fix(x,y,z,k<<1|1,mid+1,r);      pushup(k);  }

这时你注意到了上方代码第 (3) 行有一个 (pushdown(k)),那就是一个专门用来处理懒标记的函数,负责把标记下放,时间复杂度为 (O(1))

void pushdown(int k){      if(!mk[k]) return;      v[k<<1]+=mk[k],v[k<<1|1]+=mk[k];      mk[k<<1]+=mk[k],mk[k<<1|1]+=mk[k],mk[k]=0;  }

有了它,区间修改就没必要一直修改到单点了,可以修改到所属区间,然后记下懒标记。下次到这个区间的时候把它 (pushdown) 下放。

然后区间修改区间查询的代码就是这样:

#include <bits/stdc++.h>  using namespace std;  const int N=1e5+10;  int n,m,a[N];  namespace Sumtree{      #define mid ((l+r)>>1)      int v[N<<2],mk[N<<2];      void pushup(int k){v[k]=max(v[k<<1],v[k<<1|1]);}      void pushdown(int k){          if(!mk[k]) return;          v[k<<1]+=mk[k],v[k<<1|1]+=mk[k];          mk[k<<1]+=mk[k],mk[k<<1|1]+=mk[k],mk[k]=0;      }      void build(int k=1,int l=1,int r=n){          mk[k]=0;          if(l==r){v[k]=a[l];return;}          build(k<<1,l,mid),build(k<<1|1,mid+1,r);          pushup(k);      }      void fix(int x,int y,int z,int k=1,int l=1,int r=n){          if(x<=l&&r<=y){v[k]+=z,mk[k]+=z;return;}          pushdown(k);          if(mid>=x) fix(x,y,z,k<<1,l,mid);          if(mid<y) fix(x,y,z,k<<1|1,mid+1,r);          pushup(k);      }      int fmax(int x,int y,int k=1,int l=1,int r=n){          if(x<=l&&r<=y) return v[k];          pushdown(k);          int res=0;          if(mid>=x) res=max(res,fmax(x,y,k<<1,l,mid));          if(mid<y) res=max(res,fmax(x,y,k<<1|1,mid+1,r));          return res;      }      #undef mid  }using namespace Sumtree;  int main(){      scanf("%d%d",&n,&m);      for(int i=1;i<=n;i++)          scanf("%d",a+i);      build();      for(int i=1,x,y,z;i<=m;i++){          scanf("%d",&x);          if(x==1) scanf("%d%d%d",&x,&y,&z),fix(x,y,z);          else scanf("%d%d",&x,&y),printf("%dn",fmax(x,y));      }      return 0;  }

时间复杂度还是 (O(nlog n)) 的。

线段树有个经典例题,可以帮助你弄懂线段树的其他操作。

[USACO08FEB]酒店Hotel

第一行输入 (n)(m)(n) 代表有 (n) 个房间,编号为 (1sim n),开始都为空房。(m) 表示以下有 (m) 行操作,以下每行先输入一个数 (i),表示一种操作:

(i) 为1,表示查询房间。再输入一个数 (x),表示在 (1sim n) 房间中找到长度为 (x) 的连续空房,输出连续 (x) 个房间中左端的房间号,尽量让这个房间号最小。若找不到长度为 (x) 的连续空房,输出(0)。并且在这 (x) 个空房间中住上人。

(i)(2),表示退房,再输入两个数 (x)(y) 代表 房间号 (xsim x+y-1) 退房,即让房间为空。

讲解:

那么这题中每个线段树节点需要有四个值:

(texttt{lf[k]:})(k) 这个节点区间从左边开始连续空房数。
(texttt{rt[k]:})(k) 这个节点区间从右边开始连续空房数。
(texttt{v[k]:})(k) 这个节点区间内最长的连续空房数。
(texttt{mk[k]:})(k) 这个节点退房、住人区间修改懒标记。

所以有递推式(其中 (?) 为三目运算符):

void pushup(int k,int l,int r){      int mid=(l+r)>>1;      lf[k]=lf[k<<1]==mid-l+1?lf[k<<1]+lf[k<<1|1]:lf[k<<1];      rt[k]=rt[k<<1|1]==r-mid?rt[k<<1|1]+rt[k<<1]:rt[k<<1|1];      v[k]=max(max(v[k<<1],v[k<<1|1]),rt[k<<1]+lf[k<<1|1]);  }

可以这么初始化:

void build(int k=1,int l=1,int r=n){      mk[k]=-1;      if(l==r){lf[k]=rt[k]=v[k]=1;return;}      int mid=(l+r)>>1;      build(k<<1,l,mid),build(k<<1|1,mid+1,r);      pushup(k,l,r);  }

重点在于怎么查询。如下代码,(find(x,k,l,r)) 表示寻找 (k) 这个节点区间里寻找最左的连续 (x) 空房。

int find(int x,int k=1,int l=1,int r=n){      if(v[k]<x) return -1; //如果区间内最长连续空房小于x,逃      int mid=(l+r)>>1;      pushdown(k,l,r);//千万别忘了pushdown      if(v[k<<1]>=x) return find(x,k<<1,l,mid); //如果左儿子有满足要求的区间,返回左儿子满足要求的区间左端点      if(rt[k<<1]+lf[k<<1|1]>=x) return mid-rt[k<<1]+1;//如果满足区间横跨左右儿子区间,返回横跨区间左端点      return find(x,k<<1|1,mid+1,r);//返回右儿子满足区间左端点  }

可以发现,这个代码的时间复杂度也是 (O(nlog n)) 的。

蒟蒻的 (color{#44cc00}texttt{AC}) 代码:

#include <bits/stdc++.h>  using namespace std;  const int N=5e4+10;  int n,m;  namespace sumtree{      int lf[N<<2],rt[N<<2],v[N<<2],mk[N<<2];      void pushup(int k,int l,int r){          int mid=(l+r)>>1;          lf[k]=lf[k<<1]==mid-l+1?lf[k<<1]+lf[k<<1|1]:lf[k<<1];          rt[k]=rt[k<<1|1]==r-mid?rt[k<<1|1]+rt[k<<1]:rt[k<<1|1];          v[k]=max(max(v[k<<1],v[k<<1|1]),rt[k<<1]+lf[k<<1|1]);      }      void pushdown(int k,int l,int r){          if(mk[k]==-1) return;          int mid=(l+r)>>1;          lf[k<<1]=rt[k<<1]=v[k<<1]=(!mk[k])*(mid-l+1);          lf[k<<1|1]=rt[k<<1|1]=v[k<<1|1]=(!mk[k])*(r-mid);          mk[k<<1]=mk[k<<1|1]=mk[k],mk[k]=-1;      }      void build(int k=1,int l=1,int r=n){          mk[k]=-1;          if(l==r){lf[k]=rt[k]=v[k]=1;return;}          int mid=(l+r)>>1;          build(k<<1,l,mid),build(k<<1|1,mid+1,r);          pushup(k,l,r);      }      int find(int x,int k=1,int l=1,int r=n){          if(v[k]<x) return -1;          if(l==r) return l;          int mid=(l+r)>>1;          pushdown(k,l,r);          if(v[k<<1]>=x) return find(x,k<<1,l,mid);          if(rt[k<<1]+lf[k<<1|1]>=x) return mid-rt[k<<1]+1;          return find(x,k<<1|1,mid+1,r);      }      void clear(int x,int y,int k=1,int l=1,int r=n){          if(x<=l&&r<=y){              lf[k]=rt[k]=v[k]=r-l+1;              mk[k]=0; return;          }          pushdown(k,l,r);          int mid=(l+r)>>1;          if(mid>=x) clear(x,y,k<<1,l,mid);          if(mid<y) clear(x,y,k<<1|1,mid+1,r);          pushup(k,l,r);      }      void full(int x,int y,int k=1,int l=1,int r=n){          if(x<=l&&r<=y){              lf[k]=rt[k]=v[k]=0;              mk[k]=1; return;          }          pushdown(k,l,r);          int mid=(l+r)>>1;          if(mid>=x) full(x,y,k<<1,l,mid);          if(mid<y) full(x,y,k<<1|1,mid+1,r);          pushup(k,l,r);      }  }using namespace sumtree;  int main(){      scanf("%d%d",&n,&m);      build();      for(int i=1;i<=m;i++){          int op,x,y;          scanf("%d",&op);          if(op==1){              scanf("%d",&y);              if((x=find(y))==-1) puts("0");              else {                  printf("%dn",x);                  full(x,x+y-1);              }          } else {              scanf("%d%d",&x,&y);              clear(x,x+y-1);          }      }      return 0;  }

关于线段树有很多后续知识,如线段树合并,线段树分裂,可持久化线段树(主席树)等,学习千万不能停止脚步。

同时,线段树的题目千变万化,建议多练练线段树的题。

祝大家学习愉快!