线段树维护括号序列
对树进行 dfs, 入栈时加一个左括号并在其后写下该结点编号, 出栈时加一个右括号, 那么树上两点间的距离 = 括号序列中两点之间的子串在消去所有匹配的 "()" 之后剩余的括号数
例:
树 1--2--3,2 为根
括号序列为 (2(3)(1))
2 和 1 之间的子串是 (3)(, 忽略结点编号后为 ()(, 消去匹配的 () 后剩 1 个括号, 所以距离为 1; 3 和 1 之间的子串是 )(, 没有可消去的括号对, 所以距离为 2
具体的维护方法这里不展开, 请参考曹钦翔的冬令营讲稿《数据结构的提炼与压缩》(p2930)
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iostream>
#define N 100001
using namespace std;
- int front[N],
- nxt[N << 1],
- to[N << 1],
- tot;
- int id[N];
- int num[N * 3];
- bool light[N];#define max(a, b)((a) > (b) ? (a) : (b)) struct node {
- int a,
- b,
- dis;
- int right_plus,
- right_minus,
- left_plus,
- left_minus;
- void get_val(int x) {
- a = b = 0;
- right_plus = right_minus = left_plus = left_minus = dis = -1e7;
- if (num[x] == -1) b = 1;
- else if (num[x] == -2) a = 1;
- else if (!light[num[x]]) right_plus = right_minus = left_plus = left_minus = dis = 0;
- }
- node operator + (node p) {
- node k;
- k.a = max(a, a + p.a - b);
- k.b = max(p.b, p.b + b - p.a);
- k.dis = max(max(dis, p.dis), max(right_plus + p.left_minus, right_minus + p.left_plus));
- int A = a,
- B = b,
- C = p.a,
- D = p.b;
- k.right_plus = max(p.right_plus, max(right_plus + D - C, right_minus + D + C));
- k.right_minus = max(p.right_minus, right_minus + C - D);
- k.left_plus = max(left_plus, max(A + B + p.left_minus, A - B + p.left_plus));
- k.left_minus = max(left_minus, B - A + p.left_minus);
- return k;
- }
- }
- tr[N * 3 << 2];
// Reads the next non-negative decimal integer from stdin into x.
// Non-digit characters before the number are skipped. If EOF is reached
// before any digit, x is left as 0 instead of spinning forever.
void read(int &x) {
    x = 0;
    // int, not char: keeps EOF distinguishable and makes isdigit's argument
    // valid (it must be EOF or representable as unsigned char).
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    while (isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
}
- void add(int u, int v) {
- to[++tot] = v;
- nxt[tot] = front[u];
- front[u] = tot;
- to[++tot] = u;
- nxt[tot] = front[v];
- front[v] = tot;
- }
- void dfs(int x, int fa) {
- num[++tot] = -1;
- id[num[++tot] = x] = tot;
- for (int i = front[x]; i; i = nxt[i]) if (to[i] != fa) dfs(to[i], x);
- num[++tot] = -2;
- }
- void build(int k, int l, int r) {
- if (l == r) {
- tr[k].get_val(l);
- return;
- }
- int mid = l + r >> 1;
- build(k << 1, l, mid);
- build(k << 1 | 1, mid + 1, r);
- tr[k] = tr[k << 1] + tr[k << 1 | 1];
- }
- void change(int k, int l, int r, int pos) {
- if (l == r) {
- tr[k].get_val(l);
- return;
- }
- int mid = l + r >> 1;
- if (pos <= mid) change(k << 1, l, mid, pos);
- else change(k << 1 | 1, mid + 1, r, pos);
- tr[k] = tr[k << 1] + tr[k << 1 | 1];
- }
- int main() {
- int n,
- u,
- v;
- read(n);
- for (int i = 1; i < n; ++i) {
- read(u);
- read(v);
- add(u, v);
- }
- tot = 0;
- dfs(1, 0);
- build(1, 1, tot);
- int m;
- char s[3];
- read(m);
- int cnt = n;
- while (m--) {
- scanf("%s", s);
- if (s[0] == G) {
- if (cnt <= 1) printf("%d\n", cnt - 1);
- else printf("%d\n", tr[1].dis);
- } else {
- read(u);
- if (light[u]) cnt++;
- else cnt--;
- light[u] ^= 1;
- change(1, 1, tot, id[u]);
- }
- }
- return 0;
- }
来源: http://www.bubuko.com/infodetail-2508552.html