SAM 瞎学笔记

博主纯在放屁

P3804 【模板】后缀自动机(SAM)

给你一个字符串 S,要找 出现次数大于 1 的所有子串,求

$$
\max\big(\text{出现次数} \times \text{长度}\big)
$$

先把 SAM 建出来,每个 SAM 状态 u 表示的所有子串长度范围是:

$$
\text{len[link[u]]} + 1 \quad\text{到}\quad \text{len[u]}
$$

并且它们的出现次数都是 cnt[u](假设我们已经知道了)。

我们要找的是:

$$
\max_{\text{任意子串出现次数}>1} \ (cnt[u] \times L)
$$

而对于一个状态,最大的 L 是 len[u]

所以遍历所有节点,若 cnt[u] >= 2

$$
\text{ans} = \max(\text{ans}, \text{cnt[u]} \times \text{len[u]})
$$

怎么求 cnt[u] 呢,你可以在 DAG 上面 dfs,或者按 len 从大到小排序,依次 cnt[link[u]] += cnt[u]

板子是贺的 jls 的,下面就不再贴这个板子了()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
struct SAM {
static constexpr int ALPHABET_SIZE = 26;
struct Node {
int len;
int link;
std::array<int, ALPHABET_SIZE> next;
Node() : len{}, link{}, next{} {}
};
std::vector<Node> t;
SAM() { init(); }
void init() {
t.assign(2, Node());
t[0].next.fill(1);
t[0].len = -1;
}
int newNode() {
t.emplace_back();
return t.size() - 1;
}
int extend(int p, int c) {
if (t[p].next[c]) {
int q = t[p].next[c];
if (t[q].len == t[p].len + 1) {
return q;
}
int r = newNode();
t[r].len = t[p].len + 1;
t[r].link = t[q].link;
t[r].next = t[q].next;
t[q].link = r;
while (t[p].next[c] == q) {
t[p].next[c] = r;
p = t[p].link;
}
return r;
}
int cur = newNode();
t[cur].len = t[p].len + 1;
while (!t[p].next[c]) {
t[p].next[c] = cur;
p = t[p].link;
}
t[cur].link = extend(p, c);
return cur;
}
int next(int p, int x) { return t[p].next[x]; }
int link(int p) { return t[p].link; }
int len(int p) { return t[p].len; }
int size() { return t.size(); }
};

void solve() {
string s; cin >> s;
SAM sam;

vector<int> endpos;
for (int p = 1; auto ch : s) {
p = sam.extend(p, ch - 'a');
endpos.eb(p);
}

vector<int> cnt(sam.size(), 0);
for (int s : endpos) {
cnt[s] += 1;
}

vector<int> order(sam.size());
ranges::iota(order, 0);
ranges::sort(order, [&](int a, int b) {
return sam.len(a) > sam.len(b);
});

for (int u : order) {
if (u == 0) continue;
cnt[sam.link(u)] += cnt[u];
}

int ans = 0;
for (int u = 1; u < sam.size(); u++) {
if (cnt[u] >= 2) {
ans = max(ans, cnt[u] * sam.len(u));
}
}

cout << ans << '\n';
}

P2408 不同子串个数

给你一个长为 $n$ 的字符串,求不同的子串的个数。

根据 SAM 的性质,每个节点(状态)对应一个 endpos 等价类,且长度范围是 $[len(link[u]) + 1, len[u]]$

而且 每个状态的子串集合与其他状态的子串集合是互不重叠的(即不同节点表示的子串类一定不同)。

所以直接枚举节点求和即可。

因为 jls 的板子节点 1 代表的是空串,枚举的时候 $u$ 从 $2$ 开始。或者从 1 开始,最后 ans–。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void solve() {
int n = read();
string s; cin >> s;
SAM sam;

for (int p = 1; auto ch : s) {
p = sam.extend(p, ch - 'a');
}

int ans = 0;
for (int u = 2; u < sam.size(); u++) {
ans += sam.len(u) - sam.len(sam.link(u));
}
cout << ans << '\n';
}

SP1811 LCS - Longest Common Substring

输入两个字符串,输出它们的最长公共子串长度,若不存在公共子串则输出 0。

  1. 用第一个串 $A$ 建立后缀自动机(SAM)。
  2. 用第二个串 $B$ 在该 SAM 上匹配:
    • v 为当前状态(初始为 1,空串状态),len 为当前匹配长度。
    • 对于 B 的每个字符 c
      • next[v][c] 存在,v = next[v][c]len++
      • 否则沿 link 回退直到找到可走的 c,若退到哨兵节点则重置到空串状态,len=0,否则 len = len(v) + 1v = next[v][c]
    • 全程维护 ans = max(ans, len)
  3. 输出 ans

PS: 这个代码在 SPOJ 上面交一定要选 C++14 (clang 8.0),不然的话不知道为什么会 RE,很离谱。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <bits/stdc++.h>
using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128;
#define int i64
#define pb push_back
#define ep emplace
#define eb emplace_back
using namespace std;
int read(int x = 0, int f = 0, char ch = getchar()) {
while (ch < 48 or 57 < ch) f = ch == 45, ch = getchar();
while (48 <= ch and ch <= 57) x = x * 10 + ch - 48, ch = getchar();
return f ? -x : x;
}
#define debug(x) cout << #x << " = " << x << "\n";
#define vdebug(a) cout << #a << " = "; for (auto x : a) cout << x << " "; cout << "\n";
template <class T1, class T2> ostream &operator<<(ostream &os, const std::pair<T1, T2> &a) { return os << "(" << a.first << ", " << a.second << ")"; };
template <class T> ostream &operator<<(ostream &os, const vector<T> &as) { const int sz = as.size(); os << "["; for (int i = 0; i < sz; i++) { if (i >= 256) { os << ", ..."; break; } if (i > 0) { os << ", "; } os << as[i]; } return os << "]"; }
template <class T> void pv(T a, T b) { for (T i = a; i != b; i++) cerr << *i << " "; cerr << '\n'; }
using pii = pair<int, int>;
const int inf = 1e18;

struct SAM {
static constexpr int ALPHABET_SIZE = 26;
struct Node {
int len;
int link;
std::array<int, ALPHABET_SIZE> next;
Node() : len{}, link{}, next{} {}
};
std::vector<Node> t;
SAM() { init(); }
void init() {
t.assign(2, Node());
t[0].next.fill(1);
t[0].len = -1;
}
int newNode() {
t.emplace_back();
return t.size() - 1;
}
int extend(int p, int c) {
if (t[p].next[c]) {
int q = t[p].next[c];
if (t[q].len == t[p].len + 1) {
return q;
}
int r = newNode();
t[r].len = t[p].len + 1;
t[r].link = t[q].link;
t[r].next = t[q].next;
t[q].link = r;
while (t[p].next[c] == q) {
t[p].next[c] = r;
p = t[p].link;
}
return r;
}
int cur = newNode();
t[cur].len = t[p].len + 1;
while (!t[p].next[c]) {
t[p].next[c] = cur;
p = t[p].link;
}
t[cur].link = extend(p, c);
return cur;
}
int next(int p, int x) { return t[p].next[x]; }
int link(int p) { return t[p].link; }
int len(int p) { return t[p].len; }
int size() { return t.size(); }
};

void solve() {
string A, B; cin >> A >> B;
SAM sam;
for (int p = 1; auto ch : A) {
if (ch < 'a' or ch > 'z') continue;
p = sam.extend(p, ch - 'a');
}

int ans = 0;
for (int v = 1, len = 0; auto ch : B) {
int c = ch - 'a';
if (sam.next(v, c)) {
v = sam.next(v, c);
len++;
} else {
while (v and !sam.next(v, c)) {
v = sam.link(v);
}
if (v == 0) {
v = 1, len = 0;
continue;
}

len = sam.len(v) + 1;
v = sam.next(v, c);
}
ans = max(ans, len);
}
cout << ans << '\n';
}

signed main() {
// for (int T = read(); T--; solve());
solve();
return 0;
}

P4070 [SDOI2016] 生成魔咒

总共 $n$ 次操作,每次操作为加入一个字符。每次操作后都需要求出,当前的字符串的本质不同子串数量。

是前面求不同子串个数的在线版,边做边统计即可。

需要注意的是本题字符集大小是 $O(1e5)$ 的。需要把哥哥板子改成 map / umap 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
struct SAM {
// static constexpr int ALPHABET_SIZE = 26;
struct Node {
int len;
int link;
// std::array<int, ALPHABET_SIZE> next;
map<int, int> next;
Node() : len{}, link{}, next{} {}
};
std::vector<Node> t;
SAM() { init(); }
void init() {
t.assign(2, Node());
// t[0].next.fill(1);
t[0].len = -1;
}
int newNode() {
t.emplace_back();
return t.size() - 1;
}
int extend(int p, int c) {
if (p == 0) return 1;

if (t[p].next.count(c)) {
int q = t[p].next[c];
if (t[q].len == t[p].len + 1) {
return q;
}
int r = newNode();
t[r].len = t[p].len + 1;
t[r].link = t[q].link;
t[r].next = t[q].next;
t[q].link = r;
while (t[p].next[c] == q) {
t[p].next[c] = r;
p = t[p].link;
}
return r;
}
int cur = newNode();
t[cur].len = t[p].len + 1;
while (!t[p].next.count(c)) {
t[p].next[c] = cur;
p = t[p].link;
}
t[cur].link = extend(p, c);
return cur;
}
int next(int p, int x) { return t[p].next[x]; }
int link(int p) { return t[p].link; }
int len(int p) { return t[p].len; }
int size() { return t.size(); }
};

void solve() {
int n = read();
SAM sam;

int ans = 0;
for (int i = 1, p = 1; i <= n; i++) {
int x = read();
int ch = x;
p = sam.extend(p, ch);
ans += sam.len(p) - sam.len(sam.link(p));
cout << ans << '\n';
}
}

P5341 [TJOI2019] 甲苯先生和大中锋的字符串

我们有一个字符串 S 和一个数 k,要找出出现恰好 k 次的所有子串,然后按照“子串长度”分组,数一数每个长度的子串数量,取数量最多的长度输出(多解选最长),如果没有恰好 k 次的子串就输出 -1。

首先前面还是统计出现次数,后面的东东就是对每个状态,把他对应长度区间的出现次数 +1,差分维护一下即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
void solve() {
string s; cin >> s;
int k = read();
SAM sam;

vector<int> endpos;
for (int p = 1; auto ch : s) {
p = sam.extend(p, ch - 'a');
endpos.eb(p);
}

vector<int> cnt(sam.size(), 0);
for (int s : endpos) {
cnt[s] += 1;
}

vector<int> order(sam.size());
ranges::iota(order, 0);
ranges::sort(order, [&](int a, int b) {
return sam.len(a) > sam.len(b);
});

for (int u : order) {
if (u == 0) continue;
cnt[sam.link(u)] += cnt[u];
}


int n = s.size();
vector<int> c(n + 5);
for (int u = 1; u < sam.size(); u++) {
if (cnt[u] != k) continue;

int p = sam.link(u);
c[sam.len(p) + 1]++;
c[sam.len(u) + 1]--;
}

int ans = 0, res = -1;
for (int i = 1; i <= n; i++) {
c[i] += c[i - 1];
if (c[i] >= res or (c[i] == res and i > ans)) res = c[i], ans = i;
}
if (res == 0) ans = -1;

cout << ans << '\n';
}

802I - Fake News (hard)

求每个本质不同子串的出现次数平方和

模版题改改即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
void solve() {
string s; cin >> s;
SAM sam;

vector<int> endpos;
for (int p = 1; auto ch : s) {
p = sam.extend(p, ch - 'a');
endpos.eb(p);
}

vector<int> cnt(sam.size(), 0);
for (int s : endpos) {
cnt[s] += 1;
}

vector<int> order(sam.size());
ranges::iota(order, 0);
ranges::sort(order, [&](int a, int b) {
return sam.len(a) > sam.len(b);
});

for (int u : order) {
if (u == 0) continue;
cnt[sam.link(u)] += cnt[u];
}

int ans = 0;
for (int u = 2; u < sam.size(); u++) {
int t = cnt[u] * cnt[u];
t *= sam.len(u) - sam.len(sam.link(u));
ans += t;
}

cout << ans << '\n';
}

SP8093 JZPGYZ - Sevenk Love Oimaster

给定 $n$ 个文本串和 $m$ 个模式串,求对于每个模式串,有几个文本串包含它。

比较多的做法是 dfs 序 然后区间数颜色,一只 log。

但是好像直接暴力挑 parent tree 也可以,复杂度是 一个 sqrt 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
struct SAM {
// static constexpr int ALPHABET_SIZE = 26;
struct Node {
int len;
int link;
// std::array<int, ALPHABET_SIZE> next;
map<int, int> next;
Node() : len{}, link{}, next{} {}
};
std::vector<Node> t;
SAM() { init(); }
void init() {
t.assign(2, Node());
// t[0].next.fill(1);
t[0].len = -1;
}
int newNode() {
t.emplace_back();
return t.size() - 1;
}
int extend(int p, int c) {
if (p == 0) return 1;

if (t[p].next.count(c)) {
int q = t[p].next[c];
if (t[q].len == t[p].len + 1) {
return q;
}
int r = newNode();
t[r].len = t[p].len + 1;
t[r].link = t[q].link;
t[r].next = t[q].next;
t[q].link = r;
while (t[p].next[c] == q) {
t[p].next[c] = r;
p = t[p].link;
}
return r;
}
int cur = newNode();
t[cur].len = t[p].len + 1;
while (!t[p].next.count(c)) {
t[p].next[c] = cur;
p = t[p].link;
}
t[cur].link = extend(p, c);
return cur;
}
int next(int p, int x) { return t[p].next[x]; }
int link(int p) { return t[p].link; }
int len(int p) { return t[p].len; }
int size() { return t.size(); }
};

void solve() {
int n = read();
int q = read();
vector<string> a(n);
for (auto &s : a) cin >> s;

SAM sam;
vector<int> endpos;
for (const auto &s : a) {
int p = 1;
for (auto ch : s) {
p = sam.extend(p, ch);
endpos.eb(p);
}
}

int m = sam.size();
vector<int> vis(sam.size(), -1);
vector<int> sz(sam.size());

int i = 0;
for (auto s : a) {
int p = 1;
for (auto ch : s) {
p = sam.next(p, ch);

for (int now = p; now and vis[now] != i; now = sam.link(now)) {
vis[now] = i;
sz[now]++;
}
}
i++;
}

while (q--) {
string s; cin >> s;
int p = 1;
bool f = 1;
for (auto ch : s) {
if (!sam.t[p].next.count(ch)) {
f = 0;
break;
}
p = sam.next(p, ch);
}

int res = f ? sz[p] : 0;
cout << res << '\n';
}
}
Author

TosakaUCW

Posted on

2025-09-01

Updated on

2025-09-11

Licensed under

Comments