KMP 算法 从零开始
大部分来自他人博客, 蒟蒻只是总结学习
引言
字符串匹配. 给你两个字符串, 寻找其中一个字符串是否包含另一个字符串, 如果包含, 返回包含的起始位置.
- char *str = "bacbababadababacambabacaddababacasdsd";
- char *ptr = "ababaca";
暴力解法
如果当前字符匹配成功 (即 S[i] == P[j]), 则 i++,j++, 继续匹配下一个字符;
如果失配 (即 S[i]! = P[j]), 令 i = i - (j - 1),j = 0. 相当于每次匹配失败时, i 回溯, j 被置为 0.
来看看时间复杂度: 最坏情况下为 O(n*m)
所以有没有一种改进的算法
改进方法
可以实现复杂度为 O(m+n), 为何简化了时间复杂度:
KMP 算法主要是取消了指针的回溯, 充分利用了目标字符串 ptr 的性质 (比如里面部分字符串的重复性, 即使不存在重复字段, 在比较时, 实现最大的移动量). 每趟匹配过程中出现字符比较不等时, 不回溯主指针 i, 利用已得到的 "部分匹配" 结果将模式向右滑动尽可能远的一段距离, 继续进行比较.
具体概念上的我不想深究, 从代码开始理解
---
代码
KmpSearch 函数
假设现在文本串 S 匹配到 i 位置, 模式串 P 匹配到 j 位置
如果 j = -1, 或者当前字符匹配成功 (即 S[i] == P[j]), 都令 i++,j++, 继续匹配下一个字符;
如果 j != -1, 且当前字符匹配失败 (即 S[i] != P[j]), 则令 i 不变, j = next[j]. 此举意味着失配时, 模式串 P 相对于文本串 S 向右移动了 j - next [j] 位.
换言之, 当匹配失败时, 模式串向右移动的位数为: 失配字符所在位置 - 失配字符对应的 next 值, 即移动的实际位数为: j - next[j], 且此值大于等于 1.
所以 next 数组各值的含义: 代表当前字符之前的字符串中, 有多大长度的相同前缀后缀. 例如如果 next [j] = k, 代表 j 之前的字符串中有最大长度为 k 的相同前缀后缀.
此也意味着在某个字符失配时, 该字符对应的 next 值会告诉你下一步匹配中, 模式串应该跳到哪个位置 (跳到 next [j] 的位置). 如果 next [j] 等于 0 或 - 1, 则跳到模式串的开头字符, 若 next [j] = k 且 k> 0, 代表下次匹配跳到 j 之前的某个字符, 而不是跳到开头, 且具体跳过了 k 个字符.
- int KmpSearch(char* s, char* p)
- {
- int i = 0;
- int j = 0;
- int sLen = strlen(s);
- int pLen = strlen(p);
- while (i <sLen && j < pLen)
- {
- //如果 j = -1, 或者当前字符匹配成功 (即 S[i] == P[j]), 都令 i++,j++
- if (j == -1 || s[i] == p[j])
- {
- i++;
- j++;
- }
- else
- {
- //如果 j != -1, 且当前字符匹配失败 (即 S[i] != P[j]), 则令 i 不变, j = next[j]
- //next[j] 即为 j 所对应的 next 值
- j = next[j];
- }
- }
- if (j == pLen)
- return i - j;
- else
- return -1;
- }
getnext() 函数
(1) next[0] = -1;
(2) 设 next[j] = k, 则 next[j+1] = ?
令 j=j+1,k=k+1;
若 pk=pj, 则有 "p1...pk-1pk"="pj-k+1...pj-1pj" ,
next[j]=k;
若 pk+1pj+1, 可把求 next 值问题看成是一个模式匹配问题, 整个模式串既是主串, 又是子串.
即使得 k=next[k], 回溯求得最长前缀等于最长后缀的下标
j | 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
---|---|
模式串 | 0 a b c a a b b c a b c a a b d a |
next[j] | -1 0 0 0 1 1 2 0 0 1 2 3 4 5 6 0 1 |
- void getnext(int*next,char*ctr){
- next[0]=-1;
- int j=0,k=-1,len=strlen(ctr);
- while(j<len){
- if(k==-1||ctr[j]==ctr[k]){
- j++;
- k++;
- next[j]=k;
- }
- else k=next[k];
- }
- }
完整代码
- #include <bits/stdc++.h>
- using namespace std;
- const int maxn=1e6+7;
- int a[maxn],b[maxn];
- int nxt[10005],n,m;
- void getnext(){
- int i=0,j=-1;
- nxt[0]=-1;
- while(i<m){
- if(j==-1||b[i]==b[j]){
- i++,j++;
- if(b[i]==b[j])nxt[i]=nxt[j];
- else nxt[i]=j;
- }
- else j=nxt[j];
- }
- }
- int kmp(){
- int i=0,j=0;
- getnext();
- while(i<n){
- if(a[i]==b[j]||j==-1)i++,j++;
- else j=nxt[j];
- if(j==m)return i-j+1;
- }
- return -1;
- }
- int main(){
- int t;
- scanf("%d",&t);
- while(t--){
- scanf("%d%d",&n,&m);
- for(int i=0;i<n;i++)scanf("%d",&a[i]);
- for(int i=0;i<m;i++)scanf("%d",&b[i]);
- if(n<m)printf("-1\n");
- else printf("%d\n",kmp());
- }
- return 0;
- }
练习
牛客三 E
题目: 给你一个字符串 S, 你要对字符串 S 的每一位 i 将前 i 位的字符串移动到尾部形成一个新的字符串, 如果形成的字符串相同则归为一类 Li. 现在让你将 Li 类按照字典序排序, 并让你输出每一类的数量和每一类中字符串对应的下标 i
很多人用哈希, 正解是 KMP 的 next 数组的运用. 通过观察, 题目问的就是求字符串的循环节, 而 next 数组就是前后缀的运用. 然后有个结论: 如果对于一个长度为 L 的字符串, 如果 L%(L-next[L])==0 则代表它具有循环节, 且循环节的长度为 L-next[L].
- #include <bits/stdc++.h>
- using namespace std;
- const int maxn=1e6+7;
- char b[maxn];
- int net[maxn],len;
- template<class T>
- void read(T &res)
- {
- res = 0;
- char c = getchar();
- T f = 1;
- while(c <'0' || c> '9')
- {
- if(c == '-') f = -1;
- c = getchar();
- }
- while(c>= '0' && c <= '9')
- {
- res = res * 10 + c - '0';
- c = getchar();
- }
- res *= f;
- }
- template<class T>
- void out(T x)
- {
- if(x <0)
- {
- putchar('-');
- x = -x;
- }
- if(x>= 10)
- {
- out(x / 10);
- }
- putchar('0' + x % 10);
- }
- void get_next(char*ctr)
- {
- net[0]=-1;
- int j=0,k=-1;
- len=strlen(ctr);
- while(j<len)
- {
- if(k==-1||ctr[j]==ctr[k])
- {
- j++;
- k++;
- net[j]=k;
- }
- else k=net[k];
- }
- }
- int main()
- {
- scanf("%s",b);
- get_next(b);
- int tmp=len-net[len];
- if(len%tmp==0&&tmp!=len)
- {
- out(tmp);
- puts(" ");
- for(int i=0; i<tmp; i++)
- {
- out(len/tmp);
- printf(" ");
- for(int j=i;j<len;j+=tmp){
- out(j);
- //
- printf(" ");
- }
- printf("\n");
- }
- }
- else{
- out(len);
- // puts(" ");
- printf("\n");
- for(int i=0;i<len;i++){
- out(1);
- printf(" ");
- out(i);
- printf("\n");
- }
- }
- //
- return 0;
- }
来源: http://www.bubuko.com/infodetail-2705560.html