KD-tree学习笔记(超全!)
- 2019 年 10 月 5 日
- 筆記
因为之前找不到全的博客,唯一的一篇码风比较毒瘤。。。
所以我就来写了
K-D树
大概是高维二叉树吧
每次按一个维度对超空间内的点进行二分划分
树上存左右节点和这个节点所代表的的点
更新信息
我们保存几个信息:
- size 在重构的时候有用
- min[2],max[2],,就是子树中每个维度的值的最值,即处理出当前节点所代表的空间
- 题目中的其他信息,比如区间总权值
void push_up(int now){ int l=ls[now],r=rs[now];t[now].sz=t[l].sz+t[r].sz+1;t[now].sum=t[l].sum+t[r].sum+t[now].c.cnt; for(register int i=0;i<=1;i++){ t[now].mi[i]=t[now].mx[i]=t[now].c.x[i]; if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]); if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]); } }
建树
递归进行,每次选择一个维度进行划分,每次(O(N))共大约(log n)层
注意应用(nth)_(element)函数,要定义point之间的比较符号
int operator < (point a,point b){ return a.x[D]<b.x[D]; } inline int build(int l,int r,int d){ if(l>r) return 0; int now=newnode(),mid=(l+r)>>1; D=d,nth_element(p+l,p+mid,p+r+1); t[now].c=p[mid],ls[now]=build(l,mid-1,d^1),rs[now]=build(mid+1,r,d^1); push_up(now); return now; }
插入
另一种形式的建树。。。
就是找到对应的区间加点就行了,跟平衡树差不多,注意push_up
inline void insert(int &now,point p,int d){ if(!now){ now=newnode();ls[now]=rs[now]=0;t[now].c=p;push_up(now);return; } if(p.x[d]<=t[now].c.x[d]) insert(ls[now],p,d^1); else insert(rs[now],p,d^1); push_up(now);check(now,d); }
查询
每次到达一个节点,首先判断这个节点是不是被查询区间完全包含
如果是,统计答案并退出
然后分三部分查询:本节点,左右儿子区间
本节点直接判断,左右儿子区间判断是否和查询区间有交集,有就递归
有论文证明了矩形操作里面复杂度是(n ^{frac{k-1}{k}})的,k是维度数
这个复杂度很大,一般用在k=2的时候
对于高维我们可以排序去掉一维或者CDQ分治
struct sqr{ int x1,x2,y1,y2; }q; int chkin(int now,sqr tp){ return (!(t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2)); } int totalin(int now,sqr tp){ return (t[now].mx[0]<=tp.x2&&t[now].mi[0]>=tp.x1&&t[now].mx[1]<=tp.y2&&t[now].mi[1]>=tp.y1); } int ptin(point a,sqr b){ return (b.x2>=a.x[0]&&b.x1<=a.x[0]&&b.y1<=a.x[1]&&b.y2>=a.x[1]); } inline void query(int now,sqr tp){ if(!now) return 0; int re=0; if(totalin(now,tp)){ ans+=t[now].sum;return; } if(ptin(t[now].c,tp)) ans+=t[now].c.cnt; int l=ls[now],r=rs[now]; if(chkin(l,tp)) query(l,tp); if(chkin(r,tp)) query(r,tp); return re; }
k远/近询问
构造一个小/大根堆,先push几个0/inf
然后query树更新就行了,用估价函数来判断区间包含和剪枝(决定搜索顺序
复杂度不稳定,没有保证,需要卡常
下面是K远点(曼哈顿距离,我转化成平方避免小数)查询的代码
int dissqr(point tp,int a){ int di=0; for(int i=0;i<=1;i++){ int nd=0; if(tp.x[i]<t[a].mi[i]) nd=t[a].mx[i]-tp.x[i]; else if(tp.x[i]>t[a].mx[i]) nd=tp.x[i]-t[a].mi[i]; else nd=max(tp.x[i]-t[a].mi[i],t[a].mx[i]-tp.x[i]); di+=nd*nd; } return di; } void query(int now,point tp){ int di=get_dis(t[now].c,tp);if(di>q.top()) q.pop(),q.push(di); int l=ls[now],r=rs[now],dl,dr; dl=l?dissqr(tp,l):-inf,dr=r?dissqr(tp,r):-inf; if(dl>dr){ if(dl>q.top()) query(l,tp); if(dr>q.top()) query(r,tp); }else{ if(dr>q.top()) query(r,tp); if(dl>q.top()) query(l,tp); } }
重构
每次insert的时候check一下就可以啦
参考替罪羊树,设一个重构参数
还有就是注意回收节点内存,开个栈
#define alpha 0.75 int rub[N],top; inline int newnode(){ if(top) return rub[top--]; else return ++tot; } inline void clear(int now,int pos){ if(ls[now]) clear(ls[now],pos); p[pos+t[ls[now]].sz+1]=t[now].c,rub[++top]=now; if(rs[now]) clear(rs[now],pos+t[ls[now]].sz+1); } inline void check(int &now,int d){ if(alpha*(double)(t[now].sz)<(double)(t[ls[now]].sz)||alpha*(double)(t[now].sz)<(double)(t[rs[now]].sz)){ clear(now,0);now=build(1,t[now].sz,d); } } inline void insert(int &now,point p,int d){ ... check(now,d); }
完整模板
K远点对
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<queue> #define inf 192608170000000ll #define ll long long using namespace std; long long read(){ long long x=0,pos=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0; for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0'; return pos?x:-x; } const long long N = 200001; long long n,k; struct point{ long long x[2]; }p[N]; struct cmp{ long long operator()(long long a,long long b){ return a>b; } }; priority_queue<long long,vector<long long>,cmp>q; struct node{ long long mi[2],mx[2],sz;point c; }t[N]; long long rt,D,rs[N],ls[N]; long long operator < (point a,point b){ return a.x[D]<b.x[D]; } void push_up(long long now){ long long l=ls[now],r=rs[now]; t[now].sz=t[l].sz+t[r].sz+1; for(register long long i=0;i<=1;i++){ t[now].mi[i]=t[now].mx[i]=t[now].c.x[i]; if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]); if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]); } } long long tot=0; void build(long long &now,long long l,long long r,long long d){ if(l>r) return; now=++tot; long long mid=(l+r)>>1; D=d;nth_element(p+l,p+mid,p+r+1); t[now].c=p[mid]; build(ls[now],l,mid-1,d^1); build(rs[now],mid+1,r,d^1); push_up(now); } inline long long abs(long long a){ return a>0?a:-a; } long long get_dis(point a,point b){ return (a.x[0]-b.x[0])*(a.x[0]-b.x[0])+(a.x[1]-b.x[1])*(a.x[1]-b.x[1]); } long long dissqr(point tp,long long a){ long long di=0; for(long long i=0;i<=1;i++){ long long nd=0; if(tp.x[i]<t[a].mi[i]){ nd=t[a].mx[i]-tp.x[i]; }else if(tp.x[i]>t[a].mx[i]){ nd=tp.x[i]-t[a].mi[i]; }else nd=max(tp.x[i]-t[a].mi[i],t[a].mx[i]-tp.x[i]); di+=nd*nd; } return di; } void query(long long now,point tp){ long long di=get_dis(t[now].c,tp);if(di>q.top()) q.pop(),q.push(di); long long l=ls[now],r=rs[now],dl,dr; dl=l?dissqr(tp,l):-inf,dr=r?dissqr(tp,r):-inf; if(dl>dr){ if(dl>q.top()) query(l,tp); if(dr>q.top()) query(r,tp); }else{ if(dr>q.top()) query(r,tp); if(dl>q.top()) query(l,tp); } } int main(){ n=read(),k=read(); for(register long long i=1;i<=n;i++){ p[i].x[0]=read(); p[i].x[1]=read(); } build(rt,1,n,0); for(register long long i=1;i<=2*k;i++){ q.push(0); } for(long long i=1;i<=n;i++){ query(rt,p[i]); } /*putchar(10); for(long long i=1;i<=n;i++){ printf("%d %dn",p[i].x[0],p[i].x[1]); } for(long long i=1;i<=n;i++){ for(long long j=1;j<=n;j++){ printf("%d ",get_dis(p[i],p[j])); } putchar(10); }*/ printf("%lld",q.top()); return 0; }
MOKIA(三维数点)
我偏不写CDQ
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #define inf 1926081700 #define alpha 0.75 #define ll long long using namespace std; int read(){ int x=0,pos=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0; for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0'; return pos?x:-x; } const int N = 400001; int n,k,ans,lnk[N],lst[N],rub[N]; struct sqr{ int x1,x2,y1,y2; }q; struct point{ int x[2],cnt; }p[N],pn; struct node{ int mi[2],mx[2],sz,sum;point c; }t[N]; int rt,D,rs[N],ls[N],top,tot; int operator < (point a,point b){ return a.x[D]<b.x[D]; } void push_up(int now){ int l=ls[now],r=rs[now];t[now].sz=t[l].sz+t[r].sz+1;t[now].sum=t[l].sum+t[r].sum+t[now].c.cnt; for(register int i=0;i<=1;i++){ t[now].mi[i]=t[now].mx[i]=t[now].c.x[i]; if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]); if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]); } } inline int newnode(){ if(top) return rub[top--]; else return ++tot; } inline int build(int l,int r,int d){ if(l>r) return 0; int now=newnode(),mid=(l+r)>>1; D=d,nth_element(p+l,p+mid,p+r+1); t[now].c=p[mid],ls[now]=build(l,mid-1,d^1),rs[now]=build(mid+1,r,d^1); push_up(now); return now; } inline void clear(int now,int pos){ if(ls[now]) clear(ls[now],pos); p[pos+t[ls[now]].sz+1]=t[now].c,rub[++top]=now; if(rs[now]) clear(rs[now],pos+t[ls[now]].sz+1); } inline void check(int &now,int d){ if(alpha*(double)(t[now].sz)<(double)(t[ls[now]].sz)||alpha*(double)(t[now].sz)<(double)(t[rs[now]].sz)){ clear(now,0);now=build(1,t[now].sz,d); } } inline void insert(int &now,point p,int d){ if(!now){ now=newnode();ls[now]=rs[now]=0;t[now].c=p;push_up(now);return; } if(p.x[d]<=t[now].c.x[d]){ insert(ls[now],p,d^1); }else{ insert(rs[now],p,d^1); } push_up(now);check(now,d); } int chkin(int now,sqr tp){ return (!(t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2)); } int totalin(int now,sqr tp){ return (t[now].mx[0]<=tp.x2&&t[now].mi[0]>=tp.x1&&t[now].mx[1]<=tp.y2&&t[now].mi[1]>=tp.y1); } int ptin(point a,sqr b){ return (b.x2>=a.x[0]&&b.x1<=a.x[0]&&b.y1<=a.x[1]&&b.y2>=a.x[1]); } inline int query(int now,sqr tp){ if(!now) return 0; int re=0; if(totalin(now,tp)){ return t[now].sum; }else if(!chkin(now,tp)) return 0; if(ptin(t[now].c,tp)) re+=t[now].c.cnt; int l=ls[now],r=rs[now]; re+= query(l,tp); re+= query(r,tp); return re; } int main(){ int qqq=read(),ppp=read(),opt;//前两个数并没有什么用 while(opt=read()) if(opt==1){ pn.x[0]=(read()),pn.x[1]=(read()),pn.cnt=(read()); insert(rt,pn,0); }else if(opt==2){ q.x1=(read()),q.y1=(read()),q.x2=(read()),q.y2=(read()); ans=query(rt,q);printf("%dn",ans); }else return 0; return 0; }
K-D 树优化建边
NOI 2019考到了所以写一写
竟然1A了。。。(可能是之前一些KDT的题调了好久所以比较熟悉
思路跟线段树的差不多,这题不过空间开不下,所以考虑不保存边
考虑dijkstra算法中每个点只能作为中间节点松弛连的节点一次(vis)
于是建边的复杂度就跟每次直接K-D树上查询复杂度一样啦
具体来说,
- 如果当前点是原来的点,直接上树查询并松弛
- 如果是树上的点,它不可能再向树上区间连边,只连向它的左右儿子和对应的原点
码量也不是很大(还没有splay大),注意细节
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<queue> #define inf 1926081700; using namespace std; int read(){ int x=0,pos=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0; for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0'; return pos?x:-x; } const int N = 75001; struct point{ int x[2],ori; }p[N<<1]; struct node{ int mx[2],mi[2],sz,ord; point c; }t[N<<1]; int ls[N<<1],rs[N<<1]; int n,m,w,h,tot,D; int operator < (point a,point b){ return a.x[D]<b.x[D]; } int operator > (point a,point b){ return a.x[D]>b.x[D]; } inline void push_up(int now){ int l=ls[now],r=rs[now]; t[now].sz=t[l].sz+t[r].sz+1; t[now].mi[0]=t[now].mx[0]=t[now].c.x[0];t[now].mi[1]=t[now].mx[1]=t[now].c.x[1]; if(l) t[now].mi[0]=min(t[now].mi[0],t[l].mi[0]),t[now].mi[1]=min(t[now].mi[1],t[l].mi[1]),t[now].mx[0]=max(t[now].mx[0],t[l].mx[0]),t[now].mx[1]=max(t[now].mx[1],t[l].mx[1]); if(r) t[now].mi[0]=min(t[now].mi[0],t[r].mi[0]),t[now].mi[1]=min(t[now].mi[1],t[r].mi[1]),t[now].mx[0]=max(t[now].mx[0],t[r].mx[0]),t[now].mx[1]=max(t[now].mx[1],t[r].mx[1]); } inline void build(int &now,int l,int r,int d){ if(l>r) return; now=++tot;int mid=(l+r)>>1; D=d;nth_element(p+l,p+mid,p+r+1);t[now].c=p[mid];t[now].ord=p[mid].ori; build(ls[now],l,mid-1,d^1);build(rs[now],mid+1,r,d^1); push_up(now); } struct sqr{ int x1,x2,y1,y2,w; }qu[N<<1]; struct graph{ int v,nex; }edge[N<<1]; int tope=0,head[N],dis[N<<1],vis[N<<1],rt; void add(int u,int v){ edge[++tope].v=v; edge[tope].nex=head[u]; head[u]=tope; } struct type{ int pt,w; }; struct cmp{ int operator()(type a,type b){ return a.w>b.w; } }; priority_queue<type,vector<type>,cmp> q; inline type mk(int a,int b){ type nw;nw.pt=a,nw.w=b;return nw; } inline void relax(int u,int v,int w){ if(dis[v]>dis[u]+w){ dis[v]=dis[u]+w; if(!vis[v]){ q.push(mk(v,dis[v])); } } } inline int totalin(int now,sqr tp){ return (t[now].mi[0]>=tp.x1&&t[now].mx[0]<=tp.x2&&t[now].mi[1]>=tp.y1&&t[now].mx[1]<=tp.y2); } inline int totalout(int now,sqr tp){ return (t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2); } inline int ptin(point now,sqr tp){ return (now.x[0]>=tp.x1&&now.x[0]<=tp.x2&&now.x[1]>=tp.y1&&now.x[1]<=tp.y2); } inline void query(int now,sqr tp,int u){ if(totalin(now,tp)){ relax(u,now,tp.w); return; } if(ptin(t[now].c,tp)) relax(u,t[now].ord,tp.w); int l=ls[now],r=rs[now]; if(!totalout(l,tp)) query(l,tp,u); if(!totalout(r,tp)) query(r,tp,u); } inline void dijkstra(){ q.push(mk(1,0));dis[1]=0; for(int i=2;i<=tot;i++){ dis[i]=inf; } while(!q.empty()){ int now=q.top().pt;q.pop(); if(vis[now]) continue;else vis[now]=1; if(now<=n){ for(int i=head[now];i;i=edge[i].nex){ int v=edge[i].v; query(rt,qu[v],now); } }else{ relax(now,ls[now],0); relax(now,rs[now],0); relax(now,t[now].ord,0); } } for(int i=2;i<=n;i++){ printf("%dn",dis[i]); } } int main(){ n=read(),m=read(),w=read(),h=read(); for(int i=1;i<=n;i++){ p[i].x[0]=read(),p[i].x[1]=read(),p[i].ori=i; } tot=n; build(rt,1,n,1); for(int i=1;i<=m;i++){ int u=read(); qu[i].w=read(),qu[i].x1=read(),qu[i].x2=read(),qu[i].y1=read(),qu[i].y2=read(); add(u,i); } dijkstra(); return 0; }
后记
感觉数据结构也学的差不多了吧。。。
之后可能会写的数据结构博客:
top-tree/李超线段树/势能线段树/毒瘤分块题