题目描述
Luogu https://www.luogu.org/problemnew/show/P4365
题目大意: 给一棵 \(n\)个点的树, 求所有联通块中第 \(K\)大的权值 \(W_k\)之和.
数据范围:\(K\leq n\leq 1666\) , \(W_{max}\leq 1666\), 答案对 \(64123\)取模, 时限 \(7sec\).
题解
- \(Ans = \sum_{
- S
- } Kth\ of\ S = \sum_{
- v = 1
- }^W v\sum_{
- S
- } [Kth\ of\ S\ = v]\)
- \(Ans = \sum_{
- v = 1
- }^W \sum_{
- S
- } [Kth\ of\ S \ge v]\)
我们令 \(cnt(S,v)\)表示 \(S\)中权值大于等于 \(v\)的节点个数.
\(Ans = \sum_{v=1}^W \sum_{S} [cnt(S,v)\ge K]\)
然后就可以设计一个 \(dp\)了, 设 \(f_{u,v,j}\)表示 \(u\)为根的联通块中,\(W\ge v\)联通块数.
转移显然:\(f_{u,v,j} = \prod_{son_i} f_{son_i,v,k_{son_i}}\), 其中 \(\sum_{son_i} k_{son_i} = j - [W_u\ge v]\).
根据上述可得:\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j}\).
卡一下 \(j\)那一维, 用树上 \(lca\)那套分析一下, 复杂度就是严格 \(O(n^2W)\)的.
然后竟然就能成功 AC 原题数据了 qaq......
不管了.
可以注意到 \(j\)那一维是一个背包, 所以自然就能想到生成函数.
设 \(F_{u,v} = \sum_{j=0}^n f_{u,v,j} x^j\), 设 \(G_{u,v} = \sum_{s\in Tree_u} F_{s,v} = \sum_{j=0}^n g_{u,v,j}x^j\), 那么 \(F\)的转移就是一个卷积了.
即初始化后,\(F_{u,v} = \prod_{son_i} F_{son_i,v}\).
\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j} = \sum_{v=1}^W \sum_{j=K}^n g_{1,v,j} = \sum_{j=K}^n \sum_{v=1}^W g_{1,v,j}\).
我们的目标即求 \(\sum_{v=1}^W G_{1,v}\)的每一项系数.
每次转移都卷积显然是傻子.
熟悉 \(FFT\)原理的童鞋都知道先用点值表示, 最后再拉格朗日插值回去即可得到每一项的系数.
外部枚举 \(x = 1,2...n+1\), 下面我们来考虑如何计算 \(F\),\(G\)的点值表示.
注意到由于转了点值表示, 所以多项式乘法是对位乘法, 就可以用线段树合并维护了.
线段树每个叶子节点 \(v\)维护点值 \(F_{u,v},G_{u,v}\), 我们在每个点 \(u\)要干这些事:
初始化: 把区间 \([1,W_u]\)的 \(F\)加上 \(x\), 把区间 \((W_u,W_{max}]\)的 \(F\)加上 \(1\).
合并: 把 \(F_{son_i}\)对位相乘,\(G_{son_i}\)对位相加.
结束: 把 \(G_{u}\)加上 \(F_u\), 把 \(F_u\)加 \(1\), 便于下次转移.
维护一个标记 \((a,b,c,d)\), 表示 \((F,G)\ \to\ F(aF + b , cF + d + G)\).
那么初始化对应标记 \((1,x,0,0)\),\((1,1,0,0)\). 结束对应标记 \((1,1,1,0)\).
合并的时候, 线段树合并, 设合并 \(x\),\(y\).
若其中一个点 (以 \(y\) 为例)没有儿子了, 也就是说下面的节点的 \((F,G)\)都是一样的了.
此时 \(F_y = b\),\(G_y = d\), 对应 \((F_x,G_x)\to (F_xF_y,G_x+G_y)\), 修改 \(x\)的标记, 然后 \(return\)即可.
最后把根节点的线段树遍历一遍, 标记都放下去后叶子节点 \(v\)的标记中的 \(d\)即 \(G_{1,v}\).
最后套一下拉格朗日公式:\(H(x) = \sum_{i=1}^{n+1} H(i) \prod_{j\neq i} \frac{x-j}{i-j}\).
你说每次算 \(\prod(x - j)\)是 \(O(n^2)\)的?
多项式除法了解一下蟹蟹 qwq...... 先别管 \(j\neq i\)就行了.
复杂度 \(O(n^2logW)\), 被暴力吊着打, 代码其实挺短的.
实现代码
- #include<bits/stdc++.h>
- #define IL inline
- #define _ 2005
- #define ll long long
- #define ld long double
- using namespace std ;
- IL ll gi(){
- ll data = 0 , m = 1; char ch = 0 ;
- while((ch != '-') && (ch <'0' || ch> '9')) ch = getchar() ;
- if(ch == '-'){m = 0 ; ch = getchar() ; }
- while(ch>= '0' && ch <= '9'){data = (data<<1) + (data<<3) + (ch^48) ; ch = getchar() ; }
- return (m) ? data : -data ;
- }
- #define mod 64123
- int n , K , W , oo , stk[_ * _] , ans[_] , f[_] , fz[_] , H[_] , Ans , inv[mod] , rt[_] , val[_] ;
- struct _Edge{
- int to , next ;
- }Edge[_ <<1] ; int head[_] , CNT ;
- IL void AddEdge(int u , int v) {
- Edge[++ CNT] = (_Edge){v , head[u]} ; head[u] = CNT ; return ;
- }
- struct Target {
- int a , b , c , d ;
- IL Target() {a = 1 ; b = 0 ; c = 0 ; d = 0 ; }
- IL Target(int s1,int s2,int s3,int s4) {a = s1 ; b = s2 ; c = s3 ; d = s4 ; }
- } ;
- IL Target operator + (Target A , Target B) {
- Target C ;
- C.a = 1ll * A.a * B.a % mod ; C.b = (1ll * B.a * A.b % mod + B.b) % mod ;
- C.c = (1ll * A.a * B.c % mod + A.c) % mod ;
- C.d = (1ll * A.b * B.c % mod + A.d + B.d) % mod ;
- return C ;
- }
- struct Node {
- int ls , rs ; Target tag ;
- IL Node(){ls = rs = 0 ; tag = Target() ; return ; }
- }t[_ * _] ;
- IL int NewNode() {
- if(stk[0]) {t[stk[stk[0]]] = Node() ; return stk[stk[0] --] ; }
- else {t[++oo] = Node() ; return oo ; }
- }
- void PushDown(int o) {
- if(!t[o].ls) t[o].ls = NewNode() ; if(!t[o].rs) t[o].rs = NewNode() ;
- t[t[o].ls].tag = t[t[o].ls].tag + t[o].tag ;
- t[t[o].rs].tag = t[t[o].rs].tag + t[o].tag ;
- t[o].tag = Target() ;
- return ;
- }
- void Insert(int &o , int l , int r , int ql , int qr , Target E) {
- if(!o) o = NewNode() ; if(ql <= l && r <= qr) {t[o].tag = t[o].tag + E ; return ; }
- int mid = (l + r)>> 1 ;
- PushDown(o) ;
- if(ql <= mid) Insert(t[o].ls , l , mid , ql , qr , E) ;
- if(qr> mid) Insert(t[o].rs , mid + 1 , r , ql , qr , E) ;
- return ;
- }
- int Merge(int o , int os) {
- if(!o || !os) return o + os ;
- if(!t[o].ls && !t[o].rs) swap(o , os) ;
- if(!t[os].ls && !t[os].rs) {
- t[o].tag.a = 1ll * t[o].tag.a * t[os].tag.b % mod ;
- t[o].tag.b = 1ll * t[o].tag.b * t[os].tag.b % mod ;
- t[o].tag.d = (t[o].tag.d + t[os].tag.d) % mod ;
- stk[++stk[0]] = os ;
- return o ;
- }
- PushDown(o) ; PushDown(os) ;
- t[o].ls = Merge(t[o].ls , t[os].ls) ;
- t[o].rs = Merge(t[o].rs , t[os].rs) ;
- stk[++stk[0]] = os ;
- return o ;
- }
- void Dfs(int u , int From , int x) {
- Insert(rt[u] , 1 , W , 1 , val[u] , Target(1 , x , 0 , 0)) ;
- if(val[u] + 1 <= W) Insert(rt[u] , 1 , W , val[u] + 1 , W , Target(1 , 1 , 0 , 0)) ;
- for(int e = head[u] ; e ; e = Edge[e].next) {
- int v = Edge[e].to ; if(v == From) continue ;
- Dfs(v , u , x) ;
- rt[u] = Merge(rt[u] , rt[v]) ;
- }
- t[rt[u]].tag = t[rt[u]].tag + Target(1 , 1 , 1 , 0) ; return ;
- }
- void GetAns(int o , int l , int r , int x) {
- if(l == r) {H[x] = (H[x] + t[o].tag.d) % mod ; return ; }
- PushDown(o) ;
- int mid = (l + r)>> 1 ;
- GetAns(t[o].ls , l , mid , x) ; GetAns(t[o].rs , mid + 1 , r , x) ;
- return ;
- }
- IL void Solve(int x) {
- oo = 0 ; stk[0] = 0 ; for(int i = 1; i <= n; i ++) rt[i] = 0 ;
- Dfs(1 , 0 , x) ;
- H[x] = 0 ; GetAns(rt[1] , 1 , W , x) ;
- }
- IL void Lagrange() {
- inv[0] = 1 ; inv[1] = 1 ; for(int i = 2; i <mod; i ++) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod ;
- fz[0] = 1 ;
- for(int i = 1; i <= n + 1; i ++)
- for(int j = n + 1; j>= 0; j --)
- if(j) fz[j] = (fz[j - 1] + 1ll * fz[j] * (mod - i) % mod) % mod ; else fz[j] = 1ll * fz[j] * (mod - i) % mod ;
- for(int i = 1; i <= n + 1; i ++) {
- int coef = 1 ;
- for(int j = 1; j <= n + 1; j ++) if(i != j) coef = 1ll * coef * (i + mod - j) % mod ;
- for(int j = 0; j <= n + 1; j ++) f[j] = fz[j] ;
- for(int j = 0; j <= n + 1; j ++) {
- if(j) f[j] = (f[j] - f[j - 1] + mod) % mod ;
- f[j] = 1ll * inv[mod - i] * f[j] % mod ;
- }
- coef = 1ll * inv[coef] * H[i] % mod ;
- for(int j = 0; j <= n; j ++) ans[j] = (ans[j] + 1ll * coef * f[j] % mod) % mod ;
- }
- return ;
- }
- int main() {
- n = gi() ; K = gi() ; W = gi() ;
- for(int i = 1; i <= n; i ++) val[i] = gi() ;
- for(int i = 1,u,v; i < n; i ++) u = gi() , v = gi() , AddEdge(u , v) , AddEdge(v , u) ;
- Solve(1) ;
- for(int i = 1; i <= n + 1; i ++) Solve(i) ;
- //for(int i = 1; i <= n + 1; i ++) cout << "H("<<i<<") =" << H[i] << endl ;
- Lagrange() ;
- Ans = 0 ;
- for(int j = K; j <= n; j ++) Ans = (Ans + ans[j]) % mod ;
- cout << Ans << endl ;
- return 0 ;
- }
所以所谓的整体 DP 到底是啥啊, 根本没看到什么虚树的影子啊?
[九省联考 2018]秘密袭击 coat
来源: http://www.bubuko.com/infodetail-2947176.html