模板, 不解释
- #include<bits/stdc++.h>
- using namespace std;
- const int mxn=1e5+5;
- int fa[mxn],ch[mxn][2],sz[mxn],cnt[mxn],val[mxn],rt,tot;
- namespace Splay {
- void push_up(int x) {
- sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
- };
- void rotate(int x) {
- int y=fa[x],z=fa[y],tp=ch[y][1]==x;
- ch[z][ch[z][1]==y]=x,fa[x]=z; // 这里容易写错
- ch[y][tp]=ch[x][tp^1],fa[ch[x][tp^1]]=y;
- ch[x][tp^1]=y,fa[y]=x;
- push_up(y),push_up(x);
- };
- void splay(int x,int gl) {
- while(fa[x]!=gl) {
- int y=fa[x],z=fa[y];
- if(z!=gl)
- (ch[y][1]==x)^(ch[z][1]==y)?rotate(x):rotate(y);
- rotate(x);
- }
- if(gl==0) rt=x;
- };
- void find(int x) {
- int u=rt;
- while(ch[u][x>val[u]]/* 这里不一定 find 的到该值, 所以一定要加这句话 */&&x!=val[u]) u=ch[u][x>val[u]];
- splay(u,0);
- };
- int kth(int k) {
- int u=rt;
- while(1) {
- if(k<=sz[ch[u][0]]) u=ch[u][0];
- else if(k>sz[ch[u][0]]+cnt[u]) k-=sz[ch[u][0]]+cnt[u],u=ch[u][1];
- else return u;
- }
- };
- void ins(int x) {
- int u=rt,f=0;
- while(val[u]!=x&&u) f=u,u=ch[u][x>val[u]];
- if(u==0) {
- u=++tot;
- if(f) ch[f][x>val[f]]=u;
- val[u]=x; fa[u]=f;
- cnt[u]=sz[u]=1;
- }
- else ++cnt[u];
- splay(u,0);
- };
- int pre(int x) {
- find(x);
- if(val[rt]<x) return rt;
- int u=ch[rt][0];
- while(ch[u][1]) u=ch[u][1];
- return u;
- };
- int nxt(int x) {
- find(x);
- if(val[rt]>x) return rt;
- int u=ch[rt][1];
- while(ch[u][0]) u=ch[u][0];
- return u;
- };
- void erase(int x) {
- find(x);
- if(cnt[rt]>1) --cnt[rt];
- else {
- int l=pre(x),r=nxt(x); // 这里容易写错
- splay(l,0); splay(r,l);
- ch[r][0]=0;
- }
- };
- }
- int main()
- {
- using namespace Splay;
- int t,opt,x;
- scanf("%d",&t);
- ins(-1000000000),ins(1000000000);// 切记插入端点, 否则前驱后继不好求
- while(t--) {
- scanf("%d %d",&opt,&x);
- if(opt==1) ins(x);
- else if(opt==2) erase(x);
- else if(opt==3) find(x),printf("%d\n",sz[ch[rt][0]]);
- else if(opt==4) printf("%d\n",val[kth(x+1)]);
- else if(opt==5) printf("%d\n",val[pre(x)]);
- else printf("%d\n",val[nxt(x)]);
- }
- return 0;
- }
[P3369] 普通平衡树 (Splay 版)
来源: http://www.bubuko.com/infodetail-2950618.html