史上全网最清晰后缀自动机学习(四)后缀自动机里的DAG结构
- 2019 年 12 月 20 日
- 筆記
缘起
通过【1】、【2】、【3】, 我们学习了后缀自动机这种精巧的数据结构. 它本质上是利用了根据endpos对所有子串的等价类划分这一优美的结构. 【1】我们学习了SAM的基本概念, 【2】我们学习了SAM的O(n) 构造算法. 【3】我们学习了SAM上类似于ac自动机のfail树的一种数据结构——slink树. 本文继续膜SAM. hihocoder #1457 : 后缀自动机四·重复旋律7
分析
时间限制:15000ms 单点时限:3000ms 内存限制:512MB 描述 小Hi平时的一大兴趣爱好就是演奏钢琴。我们知道一段音乐旋律可以被表示为一段数构成的数列。 神奇的是小Hi发现了一部名字叫《十进制进行曲大全》的作品集,顾名思义,这部作品集里有许多作品,但是所有的 作品有一个共同特征:只用了十个音符,所有的音符都表示成0-9的数字。 现在小Hi想知道这部作品集中所有不同的旋律的“和”(也就是把串看成数字,在十进制下的求和,允许有前导0)。答 案有可能很大,我们需要对(10^9 + 7)取摸。 解题方法提示 输入 第一行,一个整数N,表示有N部作品。 接下来N行,每行包含一个由数字0-9构成的字符串S。 所有字符串长度和不超过 1000000。 输出 共一行,一个整数,表示答案 mod (10^9 + 7)。 样例输入 2 101 09 样例输出 131
首先, 这个问题和【1】~【3】中介绍的问题有一个不一样的地方——就是本题是多串. 而【1】~【3】是单串.
首先, 我们先说一下, 在【3】中我们也已经提及了. 文本串S的后缀自动机是一部恰好只能识别S的全部子串的机器 , 所以今后对于涉及子串的题目, 我们可以往后缀自动机方面去考虑.
而本题正是涉及子串的题目(因为允许有前导0, 即 0342 和 342 也看做不同的子串进行求和的.)所以考虑用sam去解决本题.
我们注意到 后缀自动机的所有状态中包含的子串的集合恰好对应原串的所有不重复子串
首先来考虑单个串的问题. 我们首先构建该串的SAM. 然后只需要求出所有节点中的子串的和——记做该节点的sum属性, 然后把所有节点的sum求和就是答案. 但是这种做法复杂度显然高了一点. 因为你还要维护每个节点的子串集合吧?
等等! 其实大家不觉得很像吗? 【3】中需要计算的是SAM节点上的endpos集合的大小, 本文需要计算的是SAM节点上的endpos集合中元素的和. 等等! 其实大家不觉得很像吗? 【3】中需要计算的是SAM节点上的endpos集合的大小, 本文需要计算的是SAM节点上的endpos集合中元素的和.
涉及SAM的题目多半是涉及子串的统计量的求解. 而子串对应到SAM的节点上去, 所以多半化归为SAM节点的某种统计量(其实就是SamNode 数据结构中的一个属性)的求解. 而SAM节点的某种统计量的求解不能暴力去求解, 因为SAM节点的个数已经是O(n)的了. 所以稍有暴力, 就堕入O(n^2)去了. 一般都是要用到slink树(SAM站在slink的视角来看待)或者DAG(SAM站在trans的视角来看待)进行dfs(dfs和DP本质没有区别)或者DP来求解
本题就要用这种思路. 既然不能暴力, 我们考虑一个SAM节点的sum属性该怎么从其他节点快速维护出来. 我们注意到了trans的意义——读入一个字符, 跳到某个状态. 而读入一个字符意思就是拼接上一个字符. 而这种操作,我们可以非常快的使用洪特规则来维护. 还是举个例子吧
假设S="1122124",其实就是我们熟悉的例子S="aabbabd"啦。我们去掉slink, 仅仅保留trans, 得到它的SAM如下

以及
状态 |
子串 |
endpos |
sum |
---|---|---|---|
0 |
空串 |
{0,1,2,3,4,5,6,7} |
0 |
1 |
1 |
{1,2,5} |
1 |
2 |
11 |
{2} |
11 |
3 |
112 |
{3} |
112 |
4 |
1122,122,22 |
{4} |
1266 |
5 |
2 |
{3,4,6} |
2 |
6 |
11221,1221,221,21 |
{5} |
12684 |
7 |
112212,12212,2212,212 |
{6} |
126848 |
8 |
12 |
{3,6} |
12 |
9 |
1122124,122124,22124,2124,124,24,4 |
{7} |
1248648 |
这里我们看看如何快速求出状态6的sum属性. 从图1可以看出, 能到达状态6的状态只有状态4或者状态5. 其中状态4是读(拼)入(接)字符"1"到达的6, 状态5也是读(拼)入(接)字符"1"到达的6
那么
状态6的sum其实=(状态5的sum*10+读入的1*状态6中包含的子串个数)+(状态4的sum*10+读入的1*状态4中包含的子串个数) = (2*10+3)+(1266*10+1) = 12661+23=12684
这不就恰好是状态6的sum了么? 所以每个SAM节点需要维护的业务字段有 sum、longest、shortest(节点包含的子串个数自然就是longest-shortest+1)
既然SAM站在trans的角度来看是一个DAG,我们自然考虑使用拓扑排序DP求解所有节点的sum属性. 那么问题来了, DP的初始化怎么搞? 注意, DAG中入度为0的节点只有起点0而已! 因为自动机是从0开始的呀~ 不可能有不从任何状态转移过来的非0状态存在! 宇宙中不允许有这么牛逼的状态存在!
而状态节点0的sum属性显然是0.
好了, 一旦我们通过拓扑排序顺便DP求出了所有节点的sum域之后, 则只需要将所有节点的sum节点求和即可得到答案.
整个算法的复杂度是O(len(s))的. len(s)是单串s的长度.
但是题目要求的是多串~ 不同串之间可以有相同的子串. 但是题目要求的是多串~ 不同串之间可以有相同的子串.
我们可以效仿后缀数组或者后缀树处理多串的方法(也就是所谓的广义), 把所有串用冒号":" (":"的ACII码是58,也就是"9"的ASCII码+1,方便处理) 连接以来. 例如以两个串"12"和"234"为例,"12:234"的SAM(只画了trans,没有画slink)如下

状态 |
子串 |
endpos |
|validnum| |
sum |
---|---|---|---|---|
S |
空串 |
{0,1,2,3,4,5,6} |
1 |
0 |
1 |
1 |
{1} |
1 |
1 |
2 |
12 |
{2} |
1 |
12 |
3 |
12:,2:,: |
{3} |
0 |
0 |
4 |
12:2,2:2,:2 |
{4} |
0 |
0 |
5 |
2 |
{2,4} |
1 |
2 |
6 |
12:23,2:23,:23,23,3 |
{5} |
2 |
26 |
7 |
12:234,2:234,:234,234,34,4 |
{6} |
3 |
272 |
上面的表格中, validnum是该节点中不包含":"的子串的个数. 例如状态6 中只有"23"和"3"两个不含":"的子串, 所以状态6的validanum=2, 而节点的sum域只统计了节点中不包含":"的子串的数字之和(这是显然的, 因为带了":"的子串你也不晓得它应该是什么数值啊~), 例如状态6的sum是23+3=26, 而不可能将"12:23, 2:23, :23" 这三个子串记录进来.
注意, 因为冒号的引入, 所以字符的数量可能扩倍.
细想一下之前的拓扑排序+DP, 对于上面带":"的后缀自动机显然也是适用的. 即我们依然有DP公式

image
所以最后剩下的一个问题是——如何维护每个节点的validnum域? 它似乎并不像单串那般 longest-shortest+1这么简单了.
我们换个角度来看这个问题. 其实一个节点中如果有子串s的话,则从自动机0开始逐个喂入s的字符的话,则最终自动机将停止在这个节点.即读入子串的过程其实就是在自动机上画出路径的过程.显然一个节点中的每个子串和这些从0出发通过trans跳转到达此节点的道路一一对应.所以一个节点中包含子串的个数不仅仅可以通过shortest、longest来求出,还可以从路径条数这一图论的观点得到.
所以一个节点的validnum域就是从自动机的0节点出发, 通过trans转移, 但是不经过":"的弧到达该节点的路径条数. 所以一个节点的validnum域就是从自动机的0节点出发, 通过trans转移, 但是不经过":"的弧到达该节点的路径条数.
而SAM站在trans的角度看就是一个DAG啊, 这一问题完全可以使用拓扑排序简单求出.
至此, 此题所有问题都已经解开, 可以开始愉快的切代码了~
//#include "stdafx.h" #include <stdio.h> #include <string.h> #include <stack> using namespace std; //#define LOCAL typedef long long ll; const int maxn = 1e6+5, SZ = 11, MOD = 1e9+7; int n, indeg[maxn<<2], tindeg[maxn<<2]; char s[maxn<<1]; struct SamNode { int trans[SZ], slink; int shortest, longest; ll sum, validnum; }sam[maxn<<2]; int newnode(int shortest, int longest, int *trans, int slink) { sam[n].shortest = shortest; sam[n].longest = longest; sam[n].slink = slink; trans?memcpy(sam[n].trans, trans, SZ*sizeof(int)):memset(sam[n].trans, -1, SZ*sizeof(int)); return n++; } int insert(char ch, int u) { int c = ch ^ 48; int z = newnode(-1, sam[u].longest+1, 0, -1); int v = u; while(~v && !~sam[v].trans[c]) { sam[v].trans[c] = z; v = sam[v].slink; } if (!~v) { sam[z].slink = 0; sam[z].shortest = 1; return z; } int x = sam[v].trans[c]; if (sam[v].longest+1 == sam[x].longest) { sam[z].slink = x; sam[z].shortest = sam[x].longest + 1; return z; } int y = newnode(-1, sam[v].longest+1, sam[x].trans, sam[x].slink); sam[x].slink = sam[z].slink = y; sam[x].shortest = sam[z].shortest = sam[y].longest + 1; while(~v && sam[v].trans[c] == x) { sam[v].trans[c] = y; v = sam[v].slink; } sam[y].shortest = sam[sam[y].slink].longest + 1; return z; } ll topsort() { stack<int> stk; stk.push(0); while(!stk.empty()) { int top = stk.top(); stk.pop(); for (int i = 0, to;i<SZ; i++) { to = sam[top].trans[i]; if (!~to) { continue; } if (i ^ SZ - 1) // 不能用带 ":" 的弧转移过来 { sam[to].sum += (sam[top].sum<<3)+(sam[top].sum<<1) + sam[top].validnum*i; sam[to].sum %= MOD; } if (!--indeg[to]) { stk.push(to); } } } ll ans = 0; for (int i = 1;i<n; i++) { ans += sam[i].sum; ans %= MOD; } return ans; } void topsort1() // 求出0到达每个节点不经过":"的弧的方法数 { for (int i = 0;i<n; i++) { for (int j = 0; j<SZ; j++) { ++indeg[sam[i].trans[j]]; } } memcpy(tindeg, indeg, sizeof(indeg)); stack<int> stk; stk.push(0); sam[0].validnum = 1; while(!stk.empty()) { int top = stk.top(); stk.pop(); for (int i = 0, to; i<SZ; i++) { to = sam[top].trans[i]; if (!~to) { continue; } if (i ^ SZ-1) // 不能用带 ":" 的弧转移过来 { sam[to].validnum += sam[top].validnum; } if (!--tindeg[to]) { stk.push(to); } } } } ll kk() { int u = newnode(0,0,0,-1), i = 1; while(s[i]) { u = insert(s[i], u); ++i; } topsort1(); return topsort(); } int main() { #ifdef LOCAL freopen("d:\data.in", "r", stdin); // freopen("d:\my.out", "w", stdout); #endif ll ans = 0; int kase, i = 1, len; scanf("%d", &kase); while(kase--) { scanf("%s", s+i); len = strlen(s+i); if (kase) { s[i+len] = ':'; } i += len+1; } printf("%lld", kk()); return 0; }
ac情况
Accepted
参考
【1】《史上全网最清晰后缀自动机学习 (一) 基本概念入门》
【2】《史上全网最清晰后缀自动机学习 (二) 后缀自动机的线性时间构造算法》
【3】《史上全网最清晰后缀自动机学习 (三) 后缀自动机里的树结构》