记录每一次分治时, 达到重心的 dep, 加到一个桶中, 没添加一个节点, 查询桶中是否存在, 之后加和和答案取最小值
- #include <cstdio>
- #include <algorithm>
- #include <cmath>
- #include <cstring>
- #include <queue>
- #include <iostream>
- #include <cstdlib>
- using namespace std;
- #define N 300505
- #define ll long long
- struct node
- {
- int to,next,val;
- }e[N<<1];
- int siz[N],dep[N],ans=1<<30,v[N*5],rot,sum,vis[N],mx[N],head[N],cnt,K,n;
- ll dis[N];
- void add(int x,int y,int z)
- {
- e[cnt].to=y;
- e[cnt].next=head[x];
- e[cnt].val=z;
- head[x]=cnt++;
- return ;
- }
- void get_root(int x,int from)
- {
- siz[x]=1;mx[x]=0;
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1]&&to1!=from)
- {
- get_root(to1,x);
- siz[x]+=siz[to1];
- mx[x]=max(siz[to1],mx[x]);
- }
- }
- mx[x]=max(mx[x],sum-siz[x]);
- if(mx[x]<mx[rot])rot=x;
- }
- void insert(int x,int from)
- {
- if(dis[x]<=K)v[dis[x]]=min(v[dis[x]],dep[x]);
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1]&&to1!=from)
- {
- insert(to1,x);
- }
- }
- }
- void clear(int x,int from)
- {
- if(dis[x]<=K)v[dis[x]]=1<<30;
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1]&&to1!=from)
- {
- clear(to1,x);
- }
- }
- }
- void calc(int x,int from)
- {
- if(dis[x]<=K)ans=min(ans,dep[x]+v[K-dis[x]]);
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(to1!=from&&!vis[to1])
- {
- dep[to1]=dep[x]+1;
- dis[to1]=dis[x]+e[i].val;
- calc(to1,x);
- }
- }
- }
- void dfs(int x)
- {
- vis[x]=1;v[0]=0;
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1])
- {
- dep[to1]=1,dis[to1]=e[i].val;
- calc(to1,0);
- insert(to1,0);
- }
- }
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1])
- {
- clear(to1,0);
- }
- }
- for(int i=head[x];i!=-1;i=e[i].next)
- {
- int to1=e[i].to;
- if(!vis[to1])
- {
- sum=siz[to1];rot=0;
- get_root(to1,0);
- dfs(rot);
- }
- }
- }
- int main()
- {
- memset(head,-1,sizeof(head));
- scanf("%d%d",&n,&K);
- for(int i=1;i<n;i++)
- {
- int x,y,z;
- scanf("%d%d%d",&x,&y,&z);
- add(x+1,y+1,z);
- add(y+1,x+1,z);
- }
- for(int i=1;i<=K;i++)v[i]=1<<30;
- mx[0]=sum=n;
- get_root(1,0);
- dfs(rot);
- if(ans==1<<30)
- {
- puts("-1");
- return 0;
- }
- printf("%d\n",ans);
- return 0;
- }
来源: http://www.bubuko.com/infodetail-2572047.html