【演算法】莫隊演算法粗略講解

  • 2020 年 3 月 16 日
  • 筆記

原文鏈接:https://cnblogs.com/ctjcalc/p/post8.html

莫隊是由莫濤大佬研究出的一種非常優秀的離線演算法,主要用來解決區間問題,甚至可以在非強制在線的情況下離線處理包括修改的操作。本文主要講解普通莫隊演算法。

前置知識

  • 分塊思想
  • 熟練掌握 STL 演算法

莫隊思想

先來看一道例子,給出下面的一個序列,給出一個區間,求區間和:

下標 1 2 3 4 5 6
(A) 2 1 1 3 4 3

如果這道題不涉及修改操作,大家都會想到預處理前綴和。現在我強制要求使用莫隊,要怎麼做呢?假設我們要求 (S_{[2,4]}) 的和,並且知道了 (S_{[2,3]}),可以想到用 (S_{[2,3]})加上 (A_4),這樣就能夠得到答案了。同理,如果我們知道了 (S_{[3,4]})(S_{[2,4]} = S_{[3,4]} + A_2)。這些操作實際上就是移動區間的一個端點,同時把新的端點的答案合併。莫隊就是這樣,如果已經知道了一個區間的答案,就試圖通過移動區間邊界,把原來的答案與新的邊界的值合併,最終讓當前區間與詢問的區間重合,然後記錄答案。

但是,現在又有一個問題了,如果有一個序列,長度為 (n),詢問 (m) 次,按照 ([1,2],[n-1,n],[3,4],[n-3,n-2],cdots) 詢問,那時間複雜度就變成了 (O(nm)) 級別了,顯然不利於解題。於是莫濤大佬又想出了一個解決方案,把序列分塊,再把詢問排序,按照排序後的詢問區間處理答案,對於兩個詢問區間,如果兩個區間的左端點在同一個塊,就比較它們的右端點,否則比較左端點。對於上面的那些詢問,排序後,時間複雜度減小至 (Theta(n+m))。對於一般情況下,這樣的時間複雜度是 (O(nsqrt{n}))但我不會證明 QwQ。

程式碼模板

通過上面的介紹,我們來系統地總結一下移動區間端點的 (4) 種情況(上面只舉例了兩種)。記當前區間為 ([l,r]),詢問區間為 ([L,R]),用程式碼說明。

情況 程式碼實現
(L < l) while (L < l) add(--l);
(R > r) while (R > r) add(++r);
(L > l) while (L > l) sub(l++);
(R < r) while (R < r) sub(r--);

前兩種情況是擴大區間,就把新端點的答案加到總答案里,後兩種相反,把舊端點的答案從總答案里刪除。

接下來,看看莫隊演算法的框架。

// ...  struct query {      int l, r, id; // 詢問區間以及它是第幾個詢問  };    constexpr int maxn = /* ... */; // 序列長度  constexpr int maxq = /* ... */; // 詢問個數  query qs[maxq];  int arr[maxn], ans[maxq], n, q, res, blocksize; // res 表示臨時計算的答案,會隨著區間的移動不斷更新    inline int blockid(int x) { return (x - 1) / blocksize + 1; }    inline void add(int x) {      // ...  }    inline void sub(int x) {      // ...  }    void solve() {      // ...      blocksize = sqrt(n); // 計算塊的大小      sort(qs + 1, qs + 1 + q, [](const query &a, const query &b) { // 排序區間          return blockid(a.l) == blockid(b.l) ? a.r < b.r : a.l < b.l;      }); // 這裡使用了 Lambda 表達式      int l = 1, r = 0; // 一開始區間是空的,這樣寫是為了符合語義      for (int i = 1; i <= q; ++i) { // 處理詢問          while (qs[i].l < l)              add(--l);          while (qs[i].r > r)              add(++r);          while (qs[i].l > l)              sub(l++);          while (qs[i].r < r)              sub(r--);          ans[qs[i].id] = res; // 記錄答案      }      // ...  }

演算法優化

排序後的區間順序其實對處理的速度有較大的影響,實際上,我們一般採用一個叫做奇偶化排序的東西。如果塊的編號是奇數,就按右端點升序排序,否則按降序排序。它的原理大致是——如果處理完奇數塊,(r) 這個指針就可以不用再跑更遠,從前開始往後掃描,可以看看下面這個詢問的例子:

// blocksize = 3  1 1  2 25  3 50  4 1  5 25  6 50

用了奇偶化排序後,速度可以提升大約 (30%)

程式碼實現:

sort(qs + 1, qs + 1 + q, [](const query &a, const query &b) {      return blockid(a.l) == blockid(b.l)                 ? (a.r == b.r ? 0 : !((blockid(a.l) & 1) && (a.r < b.r)))                 : a.l < b.l;  });

例題講解

SPOJ 3267 D-query

給出一個長度為 (n) 的序列,進行 (q) 次詢問,每次詢問給出區間 ([l,r]),問區間去重後有多少個數?

唯一需要注意的就是多維護一個數組,記為 cnt[i],表示第值為 (i) 的數目前出現了多少次。很顯然,直接在 add(x)sub(x) 里更新,看程式碼。

#include <bits/stdc++.h>  using namespace std;    template <typename T> T read() {      T x = 0, s = 1;      char c = getchar();      while (c < '0' || '9' < c) {          if (c == '-')              s = -1;          c = getchar();      }      while ('0' <= c && c <= '9') {          x = (x << 1) + (x << 3) + (c ^ 48);          c = getchar();      }      return x * s;  }    struct query {      int l, r, id;  };    constexpr int maxn = 30000 + 5;  constexpr int maxv = 1e6 + 5;  constexpr int maxq = 200000 + 5;  query qs[maxq];  int arr[maxn], cnt[maxv], ans[maxq], n, q, res, blocksize;    inline int blockid(int x) { return (x - 1) / blocksize + 1; }    inline void add(int x) {      ++cnt[arr[x]];      if (cnt[arr[x]] == 1)          ++res;  }    inline void sub(int x) {      --cnt[arr[x]];      if (!cnt[arr[x]])          --res;  }    int main() {  #ifndef ONLINE_JUDGE      freopen("Environment/project.in", "r", stdin);      freopen("Environment/project.out", "w", stdout);  #endif      n = read<int>();      blocksize = sqrt(n);      for (int i = 1; i <= n; ++i)          arr[i] = read<int>();      q = read<int>();      for (int i = 1; i <= q; ++i) {          qs[i].l = read<int>();          qs[i].r = read<int>();          qs[i].id = i;      }      sort(qs + 1, qs + 1 + q, [](const query &a, const query &b) {          return blockid(a.l) == blockid(b.l)                     ? (a.r == b.r ? 0 : (blockid(a.l) & 1) ^ (a.r < b.r))                     : a.l < b.l;      });      int l = 1, r = 0;      for (int i = 1; i <= q; ++i) {          while (qs[i].l < l)              add(--l);          while (qs[i].r > r)              add(++r);          while (qs[i].l > l)              sub(l++);          while (qs[i].r < r)              sub(r--);          ans[qs[i].id] = res;      }      for (int i = 1; i <= q; ++i)          printf("%dn", ans[i]);      return 0;  }

Luogu P1494 小Z的襪子

有一個長度為 (n) 的序列,進行 (q) 次詢問,每次詢問給出區間 ([l,r]),問隨機從區間選兩個數相等的概率是多少?使用最簡分數輸出。

同樣是要維護 cnt[i]。對於區間 ([l,r]),選取兩個數的方案數是 (binom{r-l+1}{2}),而選取到相同兩個數的方案數是 (sum_{i=1}^{N}binom{cnt[i]}{2})。概率就是它們相除。我們可以維護一下後面的這個方案數,前一個直接計算。每次移動區間端點都是差不多的,以擴大區間為例,要先減去之前的答案,再加上 cnt[i] + 1 的答案,那麼加在一起就是 (binom{cnt[i]+1}{2}-binom{cnt[i]}{2})。這個式子是可以化簡的:
[ begin{aligned}binom{cnt[i]+1}{2}-binom{cnt[i]}{2}&=frac{(cnt[i]+1)!}{2!(cnt[i]-1)!}-frac{cnt[i]}{2!(cnt[i]-2)}\&=frac{cnt[i](cnt[i]+1)}{2}-frac{cnt[i](cnt[i]-1)}{2}\&=frac{2cnt[i]}{2}\&=cnt[i]end{aligned} ]
縮小區間也是一樣的,你可以再推導一遍,也可以先把 cnt[i] 減去 (1),然後讓臨時答案減去 cnt[i],可以想想為什麼。

#include <bits/stdc++.h>  using namespace std;    template <typename T> T read() {      T x = 0, s = 1;      char c = getchar();      while (c < '0' || '9' < c) {          if (c == '-')              s = -1;          c = getchar();      }      while ('0' <= c && c <= '9') {          x = (x << 1) + (x << 3) + (c ^ 48);          c = getchar();      }      return x * s;  }    struct query {      int l, r, id;  };    constexpr int maxn = 50000 + 5;  query qs[maxn];  int arr[maxn], cnt[maxn], n, q, blocksize;  long long ans[maxn][2], res;    inline int blockid(int x) { return (x - 1) / blocksize + 1; }    inline void add(int x) { res += cnt[arr[x]]++; }    inline void sub(int x) { res -= --cnt[arr[x]]; }    long long GCD(long long x, long long y) { return y == 0 ? x : GCD(y, x % y); }    int main() {  #ifndef ONLINE_JUDGE      freopen("Environment/project.in", "r", stdin);      freopen("Environment/project.out", "w", stdout);  #endif      n = read<int>();      q = read<int>();      blocksize = sqrt(n);      for (int i = 1; i <= n; ++i)          arr[i] = read<int>();      for (int i = 1; i <= q; ++i) {          qs[i].l = read<int>();          qs[i].r = read<int>();          qs[i].id = i;      }      sort(qs + 1, qs + 1 + q, [](const query &a, const query &b) {          return blockid(a.l) == blockid(b.l)                     ? (a.r == b.r ? 0 : (blockid(a.l) & 1) ^ (a.r < b.r))                     : a.l < b.l;      });      int l = 1, r = 0;      for (int i = 1; i <= q; ++i) {          if (qs[i].l == qs[i].r) {              ans[qs[i].id][0] = 0;              ans[qs[i].id][1] = 1;              continue;          }          while (qs[i].l < l)              add(--l);          while (qs[i].r > r)              add(++r);          while (qs[i].l > l)              sub(l++);          while (qs[i].r < r)              sub(r--);          ans[qs[i].id][0] = res;          ans[qs[i].id][1] = (1LL * (r - l) * (r - l + 1)) >> 1;      }      for (int i = 1; i <= q; ++i) {          long long g = GCD(ans[i][0], ans[i][1]);          printf("%lld/%lldn", ans[i][0] / g, ans[i][1] / g);      }      return 0;  }