题目描述
Imakf 是一个小蒟蒻, 他最近刚学了 LCA, 他在手机 App 里看到一个游戏也叫做 LCA 就下载了下来.
这个游戏会给出你一棵树, 这棵树有 N 个节点, 根结点是 R, 系统会选中 M 个点 P1,P2...PM, 要 Imakf 回答有多少组点对 (ui,vi) 的最近公共祖先是 Pi.Imakf 是个小蒟蒻, 他就算学了 LCA 也做不出, 于是只好求助您了.
Imakf 毕竟学过一点 OI, 所以他允许您把答案模 (10^9+7)
输入格式
第一行 N , R , M
此后 N-1 行 每行两个数 a,b 表示 a,b 之间有一条边
此后 1 行 M 个数 表示 P_i
输出格式
M 行, 每行一个数, 第 i 行的数表示有多少组点对 (u_i,v_i) 的最近公共祖先是 P_i
N≤10000,M≤50000
显然这题根 LCA 没有多大关系......
设 size(x)表示以 x 为根的子树大小(x 自己也要算), 我们来愉快地推式子
设存在点 u,v, 并且 x 是 u,v 的祖先. 考虑 u,v 之间的路径 (不包含 u,v) 不经过 x, 那么根据往日做 LCA 的经验, 只有当 u,v 至少其中一个等于 x 时两个点的 LCA 才会是 x. 设 x 一共有 k 棵子树, 那么此时:
\[ ans1=\sum_{i=1}^{k}size[son[i]]*2+1=size[x]*2-1 \]
再考虑经过 x 的情况, 此时:
- \[ ans2=\sum_{
- i=1
- }^{
- k
- }\sum_{
- j=1
- }^{
- k
- }size[son[i]]*size[son[j]] \]
- \[ ans2=\sum_{
- i=1
- }^{
- k
- }size[son[i]]*(size[x]-1) \]
- \[ ans2=(size[x]-1)^2 \]
然后减去重复计算的 i=j 的部分:
\[ ans2=(size[x]-1)^2-\sum_{i=1}^{k}size[i]^2 \]
再把两个答案加起来:
- \[ ans=ans1+ans2=size[x]*2-1+(size[x]-1)^2-\sum_{
- i=1
- }^{
- k
- }size[i]^1 \]
- \[ ans=size[x]^2-\sum_{
- i=1
- }^{
- k
- }size[i]^2 \]
然后我们来分析复杂度. 最坏的情况就是: 根直接连接其余所有点, 并且每次询问都是根节点. 此时时间复杂度就是 O(N*M). 考虑优化.
显然重复计算过的我们不需要再算. 记录 ans 数组, 预处理出每个点的答案, 时间复杂度就变成了 O(N+M), 期望得分 100.
- #include<iostream>
- #include<cstring>
- #include<cstdio>
- #define maxn 10001
- #define p 1000000007
- using namespace std;
- struct edge{
- int to,next;
- edge(){}
- edge(const int &_to,const int &_next){ to=_to,next=_next; }
- }e[maxn<<1];
- int head[maxn],k;
- int size[maxn],ans[maxn];
- int n,m,r;
- inline int read(){
- register int x(0),f(1); register char c(getchar());
- while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
- while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
- return x*f;
- }
- inline void add(const int &u,const int &v){
- e[k]=edge(v,head[u]);
- head[u]=k++;
- }
- inline void dfs(int u,int pre){
- size[u]=1;
- for(register int i=head[u];~i;i=e[i].next){
- int v=e[i].to;
- if(v==pre) continue;
- dfs(v,u),size[u]+=size[v];
- ans[u]=(ans[u]+size[v]*size[v])%p;
- }
- ans[u]=(size[u]*size[u]%p-ans[u]+p)%p;
- }
- int main(){
- memset(head,-1,sizeof head);
- n=read(),r=read(),m=read();
- for(register int i=1;i<n;i++){
- int u=read(),v=read();
- add(u,v),add(v,u);
- }
- dfs(r,0);
- while(m--) printf("%d\n",ans[read()]);
- return 0;
- }
* 相减的部分取余需要判负数...... 或者直接加个 p 上去
来源: http://www.bubuko.com/infodetail-3052271.html