阿狸的打字机


题目描述

题解

  • 原题意简化为:询问 $x$ 串在 $y$ 串中的出现次数,多组询问

  • 输入多个串的方式构成一棵 $trie$ 树,加入一个字符相当于从当前节点往下走一层,$P$ 标记当前节点为一个串的结尾 $B$ 返回父亲节点。不只应用于 $trie$ 树,这种输入 $dfs$ 序方法建树的套路,遇到不止一次了

  • $trie$ 树已经建好,又是字符串匹配问题,自然想到建出 $AC$ 自动机

  • 考虑 $fail$ 树的意义:父亲节点必然是儿子节点的一个极大后缀($kmp$ 失配数组同理)

  • 那么对于 $y$ 串,我们一个个字符添加。如果加入一个字符后,$x$ 刚好出现,说明 $x$ 是 $y$ 的一个后缀。那么 $x$ 必然是 $y$ 在 $fail$ 树上的一个祖先。(不一定是极大后缀,不一定是父亲)

  • 问题转化为,从根到 $y$ 结尾,对应在 $tire$ 树上的每个节点,有多少个节点能直接或间接地,通过爬 $fail$ 树,达到 $x$ 末尾节点。考虑用数据结构维护

  • 考虑逆问题,在 $x$ 的 $fail$ 子树内,有多少个结点是 $y$ 串在 $trie$ 树的节点

  • 我们知道子树问题可以转化为区间问题,对于节点 $x$,我们询问在 $fail$ 树上,入栈序和出栈序的区间 $[in[x], out[x]]$

  • 在线处理询问,破坏了已经建成的 $trie$ 树的顺序和完整性,重复插入字符时间空间无法接受,考虑离线处理,按建 $trie$ 树每个串出现的顺序排序。算法呼之欲出了

  1. 建出 $trie$、$fail$ 树,遍历 $fail$ 树得 $dfs$ 序
  2. 对询问按 $y$ 从小到大排序
  3. 再次按输入顺序遍历 $trie$ 树,进入、退出一个节点在 $fail$ 树上对应 $dfs$ 序节点 $+1$、$-1$
  4. 遇到 $P$,处理对当前 $y$ 的所有询问,用树状数组统计每个 $x$ 在 $fail$ 树上的子树和

总结

$AC$ 自动机的中难题,需要有非常深入的理解,理清楚 $fail$ 树、$trie$ 树之间的关系。并想到询问离线处理,有一定数据结构基础

代码

cpp
#include <bits/stdc++.h>
using namespace std;

const int N = 2e5 + 50;

int len, ans[N], in[N], out[N], dfn;

struct Graph
{
    int etop, head[N];
    Graph () {memset(head, -1, sizeof(head));}
    struct Edge
    {
        int v, nxt;
    }e[N];

    void add(int u, int v)
    {
        e[++etop].v = v;
        e[etop].nxt = head[u];
        head[u] = etop;
    }
    void dfs(int u)
    {
        in[u] = ++dfn;
        for (int i = head[u]; ~i; i = e[i].nxt)
            dfs(e[i].v);
        out[u] = dfn;
    }
}G;

namespace BIT
{
    int a[N];
    void add(int pos, int val)
    {
        for (int i = pos; i <= dfn; i += i & -i)
            a[i] += val;
    }
    int query(int pos)
    {
        int ret = 0;
        for (int i = pos; i; i -= i & -i)
            ret += a[i];
        return ret;
    }
    int query(int l, int r)
    {
        return query(r) - query(l - 1);
    }
}

struct Query
{
    int x, y, id;
    void input(int _id)
    {
        id = _id;
        scanf("%d%d", &x, &y);
    }
    bool operator <(const Query &t) const
    {
        return y < t.y;
    }
}ask[N];

namespace AC
{
    int ch[N][26], top, fa[N], fail[N], num, pos[N];

    void trie(char *s)
    {
        int u = 0;
        for (int v, i = 0; i < len; i++)
        {
            if (s[i] == 'P')
                pos[++num] = u;
            else if (s[i] == 'B')
                u = fa[u];
            else
            {
                v = s[i] - 'a';
                if (!ch[u][v])
                    ch[u][v] = ++top;
                fa[ch[u][v]] = u;
                u = ch[u][v];
            }
        }
    }

    void build()
    {
        queue <int> q;
        for (int i = 0; i < 26; i++)
            if (ch[0][i])
            {
                q.push(ch[0][i]);
                G.add(0, ch[0][i]);
            }
        while (!q.empty())
        {
            int u = q.front();
            q.pop();
            for (int v, i = 0; i < 26; i++)
            {
                v = ch[u][i];
                if (v)
                {
                    fail[v] = ch[fail[u]][i];
                    q.push(v);
                    G.add(fail[v], v);
                }
                else ch[u][i] = ch[fail[u]][i];
            }
        }
        G.dfs(0);
    }

    void work(char *s)
    {
        int u = 0, cnt = 0, k = 1;
        for (int i = 0; i < len; i++)
        {
            if (s[i] == 'B')
            {
                BIT::add(in[u], -1);
                u = fa[u];
            }
            else if (s[i] == 'P')
            {
                ++cnt;
                while (ask[k].y == cnt)
                {
                    int x = pos[ask[k].x];
                    ans[ask[k].id] = BIT::query(in[x], out[x]);
                    ++k;
                }
            }
            else 
            {
                u = ch[u][s[i] - 'a'];
                BIT::add(in[u], 1);
            }
        }
    }
}

char s[N];

int main()
{
    scanf("%s", s);
    len = strlen(s);
    AC::trie(s);
    AC::build();
    int m;
    scanf("%d", &m);
    for (int i = 1; i <= m; i++)
        ask[i].input(i);
    sort(ask + 1, ask + m + 1);
    AC::work(s);
    for (int i = 1; i <= m; i++)
        printf("%d\n", ans[i]);
    return 0;
}

文章作者: gtxygyzb
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 gtxygyzb !