link http://uoj.ac/problem/388
题意:
给定一棵 n 个点的树, 每条边有权值, 树上两点路径长度定义为边权和. 给定一个元素在 [1,n] 的长为 m 的序列, 求出对于每个长为偶数的区间, 区间中的数字两两匹配后每对点的路径长度之和最小值. 输出所有长为偶数区间的这个最小值之和.
$n,m\leq 10^5.$
题解:
转化很巧妙.
直接算很不好算, 考虑计算每条边的贡献. 一个性质是: 假定在区间中的元素集合为 S, 对于某一条边分成的两个子树, 如果两个子树中出现在 S 中的元素个数均为奇数, 则这条边有 1 的贡献, 否则没有贡献.
证明很简单, 对于都为偶数的情况考虑反证, 如果存在路径经过这条边, 则一定至少两 (偶数) 条, 那么可以把这两条路径都删去这条边得到更优解. 对于都为奇数的情况, 一定至少存在一条经过这条边的路径, 去掉这条路径后则转化为了偶数的情况. 证毕.
那么原题转化为: 对于每个子树, 如果将子树中的元素在序列中标记为 1, 那么要求的就是这个 01 串中有多少长为偶数的区间内 1 的个数为奇数.
暴力算是 $\mathcal{O}(nm)$ 的. 我们考虑用线段树维护 01 串, 记录区间内 1 的个数, 区间内位置为奇 / 偶, 前缀和 mod2 为奇 / 偶的下标数量. 线段树合并即可. 复杂度 $\mathcal{O}(n\log m)$.
code:
- #include<bits/stdc++.h>
- #define rep(i,x,y) for (int i=(x);i<=(y);i++)
- #define ll long long
- #define inf 1000000001
- #define y1 y1___
- using namespace std;
- ll read(){
- char ch=getchar();ll x=0;int op=1;
- for (;!isdigit(ch);ch=getchar()) if (ch=='-') op=-1;
- for (;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
- return x*op;
- }
- #define N 100005
- #define M 2000005
- #define mod 998244353
- int n,m,cnt,tot,ans,head[N],rt[N],ls[M],rs[M],sum[M],a[M][2][2];
- struct edge{int to,nxt,v;}e[N<<1];
- void adde(int x,int y,int z){
- e[++cnt].to=y;e[cnt].nxt=head[x];head[x]=cnt;
- e[cnt].v=z;
- }
- void up(int k,int l,int r){
- sum[k]=0;
- if (ls[k]) sum[k]+=sum[ls[k]];
- if (rs[k]) sum[k]+=sum[rs[k]];
- int x=ls[k]?sum[ls[k]]&1:0;
- rep (i,0,1) rep (j,0,1){
- a[k][i][j]=0;
- if (ls[k]) a[k][i][j]+=a[ls[k]][i][j];
- if (rs[k]) a[k][i][j]+=a[rs[k]][i^x][j];
- }
- int mid=l+r>>1;// 注意这两句别忘
- if (!ls[k]) a[k][0][0]+=mid/2-(l-1)/2,a[k][0][1]+=(mid+1)/2-l/2;
- if (!rs[k]) a[k][x][0]+=r/2-mid/2,a[k][x][1]+=(r+1)/2-(mid+1)/2;
- }
- void ins(int &k,int l,int r,int x){
- if (!k){// 注意赋初始值
- k=++tot;
- a[k][0][0]=r/2-(l-1)/2;
- a[k][0][1]=(r+1)/2-l/2;
- }
- if (l==r){sum[k]++;return;}
- int mid=l+r>>1;
- if (x<=mid) ins(ls[k],l,mid,x);else ins(rs[k],mid+1,r,x);
- up(k,l,r);
- }
- int merge(int x,int y,int l,int r){
- if (!x||!y) return x|y;
- int mid=l+r>>1;
- ls[x]=merge(ls[x],ls[y],l,mid);
- rs[x]=merge(rs[x],rs[y],mid+1,r);
- up(x,l,r);
- return x;
- }
- void upd(int &x,int y){x+=y;x-=x>=mod?mod:0;}
- void dfs(int u,int pr){
- for (int i=head[u];i;i=e[i].nxt) if (e[i].to!=pr){
- int v=e[i].to;
- dfs(v,u);
- upd(ans,((ll)a[rt[v]][0][0]*a[rt[v]][1][0]%mod+(ll)a[rt[v]][0][1]*a[rt[v]][1][1]%mod)%mod*e[i].v%mod);
- rt[u]=merge(rt[u],rt[v],1,m+1);
- }
- }
- int main(){
- n=read(),m=read();
- rep (i,1,n-1){
- int x=read(),y=read(),z=read();
- adde(x,y,z);adde(y,x,z);
- }
- rep (i,1,m) ins(rt[read()],1,m+1,i);
- dfs(1,0);
- cout<<ans<<'\n';
- return 0;
- }
- View Code
易错:
注意初始情况不是 0, 需要处理.
来源: http://www.bubuko.com/infodetail-2685422.html