大數翻倍法求解CRT

註:做法和思路是 zhx 在一次講課中提出的,如有侵權,請聯繫作者刪除

其實別的題解也有提到過暴力做法,但這裡將會給出更加嚴謹的複雜度的證明

正文

引入

我們知道,中國剩餘定理是一種用來求解類似於

\[\begin{cases}
x \equiv a_1 \pmod {m_1} \\
x \equiv a_2 \pmod {m_2} \\
x \equiv a_3 \pmod {m_3} \\
… \\
x \equiv a_4 \pmod {m_4} \\
\end{cases}
\]

形式的同餘方程組的定理,要求我們找出 \(x\) 的最小非負整數解

大數翻倍法

現在市面上比較推廣的一種方法是用擴展歐幾里得來求解同餘方程組。

這裡將介紹一種更為暴力的演算法——大數翻倍法,寫起來也更加方便簡潔。

先來考慮兩個同餘方程的情況:

\[\begin{cases}
x \equiv a_1 \pmod {m_1} \\
x \equiv a_2 \pmod {m_2}
\end{cases}
\]

考慮用一種暴力的方法將其合併成一個同餘方程。讓我們設初始的 \(x = 0, m = 1\),合併了第一個方程後變為 \(x = a_1, m = m_1\)

那麼現在只需要滿足第二個同餘方程即可。我們知道 \((a_1 + km_1) \mod m_1 = a_1\),一個顯然的想法是每次暴力的加 \(m_1\),然後暴力的判斷能否滿足第二個同餘方程。找到一個能滿足的情況合併即可,模數合併為 $ \operatorname{lcm}(m_1,m_2)$,程式碼也十分好寫,只有四行:

void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

複雜度證明

根據費馬小定理我們知道 \(a^{p-1} \equiv 1 \pmod p\) ,

又因為 \(a^0 \equiv 1 \pmod p\) ,所以得到

\[a^{p-1} \equiv a^0 \pmod p
\]

將其推廣就會有:

\[a^{x + p-1} \equiv a^x \pmod p
\]

這說明了什麼?

\(a^x\) 在模 \(p\) 下的循環節,在最壞情況下只有 \(p-1\) 大小。

所以上面程式碼每次合併的複雜度是 \(O(m_2)\) 的。發現更小的模數的複雜度更優,所以我們添一句優化,通過特判轉換一下枚舉的模數即可。程式碼改為:

void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    if(m1 < m2) swap(m1, m2), swap(a1, a2);
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

所以總的複雜度為 \(O(\sum_{i=1}^{n}m_i)\)

但是!它的複雜度真的有那麼高嗎?(那我也沒必要寫這篇部落格了是吧

我們知道答案一定在 long long 範圍內,並且 \(\prod_{i=1}^{n} m_i\) 一定也不會爆 long long

因為高精度求解同餘方程組也沒那個做法是吧,出題人也一定不會出個爆 long long 的樣例,因為他自己也做不了。

讓我們來考慮最壞情況:

想要卡我們,每個模數都得是一個大質數。還要保證成績和在 long long 範圍內(也就是 \(10^{18}\))。

那麼只有一種情況, \(n = 2\)!此時 \(m_i\) 可以做到 \(2 \times 10^9\) 級別的大質數。總時間複雜度為 \(O(10^9)\) ,可以被卡。

但是,當 \(n = 3\) 時, \(m_i\) 只有 \(10^6\) 級別,我們的複雜度也只有 \(O(3 \times 10^6)\) ,可以通過。

\(n\) 更大的情況就不必說了吧。

大數翻倍法的優勢

  • 碼量小
  • 理解難度小
  • 一般不會被卡,沒有人會對著這個非主流演算法卡十個點的
  • 不需要考慮模數互質的情況

最後的最後:上程式碼!

/*
Work by: Suzt_ilymics
Problem: 不知名屑題
Knowledge: 大數翻倍法
Time: O(能過)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define LL long long
#define orz cout<<"lkp AK IOI!"<<endl
using namespace std;
LL n, v, d, a, b;
LL read(){
    LL s = 0, f = 0;
    char ch = getchar();
    while(!isdigit(ch))  f |= (ch == '-'), ch = getchar();
    while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0' , ch = getchar();
    return f ? -s : s;
}

LL Gcd(LL x, LL y) { return !y ? x : Gcd(y, x % y); }
LL Lcm(LL x, LL y) { return x / Gcd(x, y) * y; }
void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    if(m1 < m2) swap(m1, m2), swap(a1, a2);
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

int main()
{
    n = read(); v = 0, d = 1; // 初始化 
    for(int i = 1; i <= n; ++i) a = read(), b = read(), b %= a, Merge(v, d, b, a);
    printf("%lld", v);
    return 0;
}

如果覺得寫的不錯就點個贊吧這個做法頂上去吧/kel