考虑定义 $ F(x) $ 为 \(x\) 为根的子树所有点与 $ x $ 的深度差 (其实就是 $ x $ 到每个子树内点的距离) 的 1 的个数和.
注意,$ F(x) $ 的值不是答案, 但是只需要一点树形 dp 的基础内容就可以变成要求的答案.
对于一个点 $ u $ , 考虑它的一个儿子 $ v $ , 我们此时已经计算出了 $ F( v ) $ 的值那么怎么统计 $ v $ 中所有点对于 $ u $ 的贡献呢? 首先考虑 $ F(v) $ 的变化, 由于当前的点 $ u $ 是 $ v $ 的父亲,$ v $ 中所有点到 $ u $ 的距离实际上是原来到 $ v $ 的路径长度 + 1. 那么二进制中 1 的个数加了多少呢?
对于一个 $ v $ 子树中点 $ k $, 假设它到 $ v $ 的距离是 $ d $, 则:
如果 $ d \equiv 0 \pmod 2 $ 那么显然二进制 1 的个数直接 + 1
如果 $ d \equiv 1 \pmod 2 $ 那么二进制中 1 的个数 不变
如果 $ d \equiv 3 \pmod {2^2} $ 那么二进制中 1 的个数 少 1
如果 $ d \equiv 7 \pmod {2^3} $ 那么二进制中 1 的个数 少 1
...
那么就有了一个思路, 把 $ v $ 子树中与 $ v $ 距离 $ d \equiv {2^k - 1} \pmod {2^k} $ 的点的个数存着, 这个可以倍增预处理.
那么对于 $ F $ 我们就会转移了, 先 + 上子树的 size, 然后减去 $ v $ 子树中 $ 2^k - 1 $ 距离的点的个数.
转移了 $ F $ 后, 直接给 $ v $ 中的 $ F $ 乘上 $ size(u) - size(v) $ (这个是显然的树形 dp 了)
- #include<iostream>
- #include<cstring>
- #include<cstdio>
- #include<algorithm>
- using namespace std;
- #define MAXN 100006
- typedef long long ll;
- int n;
- int read( ) {
- int ret = 0; char ch = ' ';
- while( ch <'0' || ch> '9' ) ch = getchar();
- while( ch>= '0' && ch <= '9' ) ret *= 10 , ret += ch - '0' , ch = getchar();
- return ret;
- }
- int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn = 0;
- void ade( int u , int v ) {
- to[++ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
- }
- int G[MAXN][18] , GG[MAXN][18]; ll t[MAXN][18] ; // G 2^k , GG 2^{k - 1} , t how many nodes at dep % 2^k = 2^k - 1
- int siz[MAXN];
- void dfs( int u , int fa ) {
- siz[u] = 1;
- for( int i = head[u] ; i ; i = nex[i] ) {
- int v = to[i];
- if( v == fa ) continue;
- G[v][0] = u , GG[v][0] = v;
- for( int k = 1 ; k < 18 ; ++ k ) {
- if( G[G[v][k-1]][k-1] )
- G[v][k] = G[G[v][k-1]][k-1];
- if( G[GG[v][k-1]][k-1] )
- GG[v][k] = G[GG[v][k-1]][k-1];
- else break;
- }
- dfs( v , u );
- siz[u] += siz[v];
- }
- for( int k = 1 ; k < 18 ; ++ k ) {
- if( G[u][k] )
- t[G[u][k]][k] += t[u][k];
- if( GG[u][k] )
- ++ t[GG[u][k]][k];
- else break;
- }
- }
- ll res = 0;
- ll T[MAXN];
- ll solve( int u , int fa ) {
- ll R = 0 , ret = 0;
- for( int i = head[u] ; i ; i = nex[i] ) {
- int v = to[i];
- if( v == fa ) continue;
- R = 0;
- ll lst = solve( v , u );
- R += lst + siz[v];
- R -= T[v];
- res += R * ( siz[u] - siz[v] );
- ret += R;
- }
- return ret;
- }
- signed main( ) {
- n = read();
- for( int i = 1 , u , v ; i < n ; ++ i ) {
- u = read() , v = read();
- ade( u , v ) , ade( v , u );
- }
- dfs( 1 , 1 );
- for( int i = 1 ; i <= n ; ++ i )
- for( int k = 1 ; k < 18 ; ++ k )
- T[i] += t[i][k];
- solve( 1 , 1 );
- printf("%lld",res);
- }
来源: http://www.bubuko.com/infodetail-3195295.html