题意: 若两个字符开始的后面 r 个字符都一样, 则称这两个字符是 r 相似的. 它们也是 r-1 相似的.
对于 r∈[0,n) 分别求有多少种方案, 其中权值最大方案权值是多少. 此处权值是选出的两个字符的权值之积.
解: 后缀自动机吊打后缀数组!!!
先看第一问, 我们考虑后缀自动机上每个节点的贡献. 显然 cnt>1 的节点才会有贡献.
它会对 r ∈ len[fail[x]] + 1 ~ len[x] 这一段的答案产生 C(cntx,2) 的贡献. 这就是一个区间加法.
有个小问题, 如果 r 减少那么相应的可选的其实会变多, 但是此处我们不统计, 那些会在以另一个字符结尾的别的节点上考虑到.
这样第一问就解决了. 第二问? 发现问题很大... 一个节点的每个串都是结尾相同, 开头不同. 那么开头的权值之积就不好维护.
因为结尾相同, 所以考虑反着建后缀自动机, 然后就变成了开头相同了. 那么如何维护乘积最大值呢?
考虑到一个节点的末尾所在位置, 也就是它的 right 集合. 显然就是 fail 树的子树中所有在主链上的节点.
于是每个点维护子树最大值即可. 因为有负数所以还要最小值.
这样一个节点对第二问的贡献就是区间取 max 了. 这两问的操作都可以用线段树搞定.
啊! 我完全独立的一 A 了这道题! 鼓掌! 啪啪啪啪啪啪......
- #include <cstdio>
- #include <cstring>
- #include <algorithm>
- typedef long long LL;
- const int N = 600010;
- int tr[N][26], tot = 1, fail[N], len[N], cnt[N], bin[N], topo[N], last = 1, val[N], n;
- int max1[N], max2[N], min1[N], min2[N];
- char str[N];
- LL tag[N <<1], large[N << 1];
- inline void insert(char c) {
- int f = c - 'a';
- int p = last, np = ++tot;
- last = np;
- len[np] = len[p] + 1;
- cnt[np] = 1;
- while(p && !tr[p][f]) {
- tr[p][f] = np;
- p = fail[p];
- }
- if(!p) {
- fail[np] = 1;
- }
- else {
- int Q = tr[p][f];
- if(len[Q] == len[p] + 1) {
- fail[np] = Q;
- }
- else {
- int nQ = ++tot;
- len[nQ] = len[p] + 1;
- fail[nQ] = fail[Q];
- fail[Q] = fail[np] = nQ;
- memcpy(tr[nQ], tr[Q], sizeof(tr[Q]));
- while(tr[p][f] == Q) {
- tr[p][f] = nQ;
- p = fail[p];
- }
- }
- }
- return;
- }
- inline void update(int x, int y) {
- int t[4];
- t[0] = max1[x];
- t[1] = max2[x];
- t[2] = max1[y];
- t[3] = max2[y];
- std::sort(t, t + 4);
- max1[x] = t[3];
- max2[x] = t[2];
- t[0] = min1[x];
- t[1] = min2[x];
- t[2] = min1[y];
- t[3] = min2[y];
- std::sort(t, t + 4);
- min1[x] = t[0];
- min2[x] = t[1];
- return;
- }
- inline void prework() {
- for(int i = 1; i <= tot; i++) {
- bin[len[i]]++;
- }
- for(int i = 1; i <= tot; i++) {
- bin[i] += bin[i - 1];
- }
- for(int i = 1; i <= tot; i++) {
- topo[bin[len[i]]--] = i;
- }
- for(int i = tot; i>= 1; i--) {
- int a = topo[i];
- cnt[fail[a]] += cnt[a];
- update(fail[a], a);
- }
- return;
- }
- inline void pushdown(int o) {
- if(tag[o]) {
- tag[o <<1] += tag[o];
- tag[o << 1 | 1] += tag[o];
- tag[o] = 0;
- }
- large[o << 1] = std::max(large[o << 1], large[o]);
- large[o << 1 | 1] = std::max(large[o << 1 | 1], large[o]);
- return;
- }
- void add(int L, int R, LL v, int l, int r, int o) {
- if(L <= l && r <= R) {
- tag[o] += v;
- return;
- }
- int mid = (l + r)>> 1;
- pushdown(o);
- if(L <= mid) {
- add(L, R, v, l, mid, o <<1);
- }
- if(mid < R) {
- add(L, R, v, mid + 1, r, o << 1 | 1);
- }
- return;
- }
- void out(int l, int r, int o) {
- if(l == r) {
- if(r != n) {
- printf("%lld %lld \n", tag[o], tag[o] ? large[o] : 0);
- }
- return;
- }
- int mid = (l + r)>> 1;
- pushdown(o);
- out(l, mid, o <<1);
- out(mid + 1, r, o << 1 | 1);
- return;
- }
- void change(int L, int R, LL v, int l, int r, int o) {
- if(v <= large[o]) {
- return;
- }
- if(L <= l && r <= R) {
- large[o] = v;
- return;
- }
- int mid = (l + r)>> 1;
- pushdown(o);
- if(L <= mid) {
- change(L, R, v, l, mid, o <<1);
- }
- if(mid < R) {
- change(L, R, v, mid + 1, r, o << 1 | 1);
- }
- return;
- }
- int main() {
- memset(max1, ~0x3f, sizeof(max1));
- memset(max2, ~0x3f, sizeof(max2));
- memset(min1, 0x3f, sizeof(min1));
- memset(min2, 0x3f, sizeof(min2));
- memset(large, ~0x3f, sizeof(large));
- scanf("%d", &n);
- scanf("%s", str + 1);
- int l1 = -0x3f3f3f3f, l2 = -0x3f3f3f3f, s1 = 0x3f3f3f3f, s2 = 0x3f3f3f3f;
- for(int i = 1; i <= n; i++) {
- scanf("%d", &val[i]);
- if(l1 < val[i]) {
- l2 = l1;
- l1 = val[i];
- }
- else if(l2 < val[i]) {
- l2 = val[i];
- }
- if(s1> val[i]) {
- s2 = s1;
- s1 = val[i];
- }
- else if(s2> val[i]) {
- s2 = val[i];
- }
- }
- for(int i = n; i>= 1; i--) {
- insert(str[i]);
- max1[last] = min1[last] = val[i];
- }
- prework();
- //
- for(int i = 2; i <= tot; i++) {
- if(cnt[i] < 2) {
- continue;
- }
- // len[fail[i]] + 1 ~ len[i]
- add(len[fail[i]] + 1, len[i], 1ll * cnt[i] * (cnt[i] - 1) / 2, 1, n, 1);
- change(len[fail[i]] + 1, len[i], std::max(1ll * max1[i] * max2[i], 1ll * min1[i] * min2[i]), 1, n, 1);
- }
- printf("%lld %lld \n", 1ll * n * (n - 1) / 2, std::max(1ll * l1 * l2, 1ll * s1 * s2));
- out(1, n, 1);
- return 0;
- }
- View Code
来源: http://www.bubuko.com/infodetail-2936752.html