算法竞赛进阶指南-54.后缀数组

题目链接

算法竞赛进阶指南-54.后缀数组

后缀数组 (SA) 是一种重要的数据结构,通常使用倍增或者 DC3 算法实现,这超出了我们的讨论范围。

在本题中,我们希望使用快排、Hash 与二分实现一个简单的 O(nlog2n)O(nlog^2n) 的后缀数组求法。

详细地说,给定一个长度为 nn 的字符串 SS(下标 0n10 \sim n-1),我们可以用整数 kk(0k<n0 \le k < n) 表示字符串 SS 的后缀 S(kn1)S(k \sim n-1)

把字符串 SS 的所有后缀按照字典序排列,排名为 ii 的后缀记为 SA[i]SA[i]

额外地,我们考虑排名为 ii 的后缀与排名为 i1i-1 的后缀,把二者的最长公共前缀的长度记为 Height[i]Height[i]

我们的任务就是求出 SASAHeightHeight 这两个数组。

输入格式

输入一个字符串,其长度不超过 3030 万。

字符串由小写字母构成。

输出格式

第一行为数组 SASA,相邻两个整数用 11 个空格隔开。

第二行为数组 HeightHeight,相邻两个整数用 11 个空格隔开,我们规定 Height[1]=0Height[1]=0

输入样例:

ponoiiipoi

输出样例:

9 4 5 6 2 8 3 1 7 0
0 1 2 1 0 0 2 1 0 2

Method : 快排+二分+字符串前缀哈希

首先搞明白题目意思:题目样例中的10个后缀

0 ponoiiipoi
1 onoiiipoi
2 noiiipoi
3 oiiipoi
4 iiipoi
5 iipoi
6 ipoi
7 poi
8 oi
9 i

输出:
(1) 字典序排序后的下标
(2) 排序后,每两个相邻后缀的最长公共前缀长度

然后依照题意,要以字典序对这些后缀排序:

快排本身O(nlogn)O(n\log n),然后如果每次扫描字符串按字典序比较O(n)O(n),总时间复杂度会达到O(n2logn)O(n^2\log n)

我们可以考虑使用二分+字符串前缀哈希,将字典序比较优化成O(logn)O(\log n)
在快速排序比较后缀p和后缀q时,可以使用二分法O(logn)O(\log n),比较S[pp+mid1]S[p \sim p + mid - 1]S[qq+mid1]S[q \sim q + mid - 1]的字符串前缀哈希值是否相等,从而求得后缀p和后缀q的最长公共前缀长度lenlen,然后再比较两个字符串的下一个字符O(1)O(1)(也就是第一个不相等的位置)S[p+len]S[p + len]S[q+len]S[q + len],就能确定这两个字符串的字典序大小了。

然后就是喜闻乐见的日常cin>>字符串被卡,要开O2优化才能过。

#pragma GCC optimize(2)

#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
#include <limits.h>

using namespace std;

typedef unsigned long long ULL;

const int N = 3e5 + 7;
const int P = 131;

string s;            // 字符串
int n ;              // s的长度
vector<ULL> h;       // hash
vector<int> sa(N);   // 后缀按字典序排序后的下标

// 字符串hash
vector<ULL> p(N);
vector<ULL> get_prefix_hash(const string& s) {
    int n = s.size();
    vector<ULL> h(n + 1);
    h[0] = 0;
    p[0] = 1;
    for (int i = 1; i <= n; i ++) {
        h[i] = h[i - 1] * P + s[i - 1];
        p[i] = p[i - 1] * P;
    }

    return h;
}

ULL get_substr_hash(const vector<ULL>& h, int l, int r) {
    return h[r] - h[l - 1] * p[r - l + 1];
}

int get_max_common_prefix(int a, int b) {
    int l = 0, r = min(n - a + 1, n - b + 1);
    while (l != r + 1) {
        int mid = l + r >> 1;
        if (get_substr_hash(h, a, a + mid - 1) != get_substr_hash(h, b, b + mid - 1)) {
            // 前缀不等,缩小len
            r = mid - 1;
        }
        else {
            // 前缀相等,扩大len
            l = mid + 1;
        }
    }

    return r;
}

bool cmp(int a, int b) {
    int len = get_max_common_prefix(a, b);
    int a_nxt = a + len > n ? INT_MIN : s[a + len - 1];
    int b_nxt = b + len > n ? INT_MIN : s[b + len - 1];

    return a_nxt < b_nxt;
}

void quick_sort(vector<int>& arr, int l, int r) {
    if (l >= r) return;

    int x = arr[l + (r - l) / 2];

    int i = l - 1, j = r + 1;
    while (i < j) {
        do i ++; while (cmp(arr[i], x) == 1);
        do j --; while (cmp(x, arr[j]) == 1);
        if (i < j) swap(arr[i], arr[j]);
    }

    quick_sort(arr, l, j);
    quick_sort(arr, j + 1, r);
}



int main() {
    cin >> s;
    n = s.size();
    h = get_prefix_hash(s);

    // 初始化sa
    for (int i = 1; i <= n; i ++) {
        sa[i] = i;
    }

    quick_sort(sa, 1, n);
    // sort(sa.begin() + 1, sa.begin() + 1 + n, cmp);

    for (int i = 1; i <= n; i ++) {
        printf("%d ", sa[i] - 1);
    }
    puts("");
    for (int i = 1; i <= n; i ++) {
        if (i == 1) printf("0 ");
        else printf("%d ", get_max_common_prefix(sa[i - 1], sa[i]));
    }
    puts("");

    return 0;
}

复杂度分析

  • 时间复杂度:O(nlog2n){O(n\log^2 n)}, 其中nlognn\log n为快排,logn\log n为字典序比较(二分+字符串hash)。

  • 空间复杂度:O(N){O(N)}