大数翻倍法求解CRT

注:做法和思路是 zhx 在一次讲课中提出的,如有侵权,请联系作者删除

其实别的题解也有提到过暴力做法,但这里将会给出更加严谨的复杂度的证明

正文

引入

我们知道,中国剩余定理是一种用来求解类似于

\[\begin{cases}
x \equiv a_1 \pmod {m_1} \\
x \equiv a_2 \pmod {m_2} \\
x \equiv a_3 \pmod {m_3} \\
… \\
x \equiv a_4 \pmod {m_4} \\
\end{cases}
\]

形式的同余方程组的定理,要求我们找出 \(x\) 的最小非负整数解

大数翻倍法

现在市面上比较推广的一种方法是用扩展欧几里得来求解同余方程组。

这里将介绍一种更为暴力的算法——大数翻倍法,写起来也更加方便简洁。

先来考虑两个同余方程的情况:

\[\begin{cases}
x \equiv a_1 \pmod {m_1} \\
x \equiv a_2 \pmod {m_2}
\end{cases}
\]

考虑用一种暴力的方法将其合并成一个同余方程。让我们设初始的 \(x = 0, m = 1\),合并了第一个方程后变为 \(x = a_1, m = m_1\)

那么现在只需要满足第二个同余方程即可。我们知道 \((a_1 + km_1) \mod m_1 = a_1\),一个显然的想法是每次暴力的加 \(m_1\),然后暴力的判断能否满足第二个同余方程。找到一个能满足的情况合并即可,模数合并为 $ \operatorname{lcm}(m_1,m_2)$,代码也十分好写,只有四行:

void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

复杂度证明

根据费马小定理我们知道 \(a^{p-1} \equiv 1 \pmod p\) ,

又因为 \(a^0 \equiv 1 \pmod p\) ,所以得到

\[a^{p-1} \equiv a^0 \pmod p
\]

将其推广就会有:

\[a^{x + p-1} \equiv a^x \pmod p
\]

这说明了什么?

\(a^x\) 在模 \(p\) 下的循环节,在最坏情况下只有 \(p-1\) 大小。

所以上面代码每次合并的复杂度是 \(O(m_2)\) 的。发现更小的模数的复杂度更优,所以我们添一句优化,通过特判转换一下枚举的模数即可。代码改为:

void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    if(m1 < m2) swap(m1, m2), swap(a1, a2);
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

所以总的复杂度为 \(O(\sum_{i=1}^{n}m_i)\)

但是!它的复杂度真的有那么高吗?(那我也没必要写这篇博客了是吧

我们知道答案一定在 long long 范围内,并且 \(\prod_{i=1}^{n} m_i\) 一定也不会爆 long long

因为高精度求解同余方程组也没那个做法是吧,出题人也一定不会出个爆 long long 的样例,因为他自己也做不了。

让我们来考虑最坏情况:

想要卡我们,每个模数都得是一个大质数。还要保证成绩和在 long long 范围内(也就是 \(10^{18}\))。

那么只有一种情况, \(n = 2\)!此时 \(m_i\) 可以做到 \(2 \times 10^9\) 级别的大质数。总时间复杂度为 \(O(10^9)\) ,可以被卡。

但是,当 \(n = 3\) 时, \(m_i\) 只有 \(10^6\) 级别,我们的复杂度也只有 \(O(3 \times 10^6)\) ,可以通过。

\(n\) 更大的情况就不必说了吧。

大数翻倍法的优势

  • 码量小
  • 理解难度小
  • 一般不会被卡,没有人会对着这个非主流算法卡十个点的
  • 不需要考虑模数互质的情况

最后的最后:上代码!

/*
Work by: Suzt_ilymics
Problem: 不知名屑题
Knowledge: 大数翻倍法
Time: O(能过)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define LL long long
#define orz cout<<"lkp AK IOI!"<<endl
using namespace std;
LL n, v, d, a, b;
LL read(){
    LL s = 0, f = 0;
    char ch = getchar();
    while(!isdigit(ch))  f |= (ch == '-'), ch = getchar();
    while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0' , ch = getchar();
    return f ? -s : s;
}

LL Gcd(LL x, LL y) { return !y ? x : Gcd(y, x % y); }
LL Lcm(LL x, LL y) { return x / Gcd(x, y) * y; }
void Merge(LL &a1, LL &m1, LL a2, LL m2) {
    if(m1 < m2) swap(m1, m2), swap(a1, a2);
    while(a1 % m2 != a2) a1 += m1;
    m1 = Lcm(m1, m2);
}

int main()
{
    n = read(); v = 0, d = 1; // 初始化 
    for(int i = 1; i <= n; ++i) a = read(), b = read(), b %= a, Merge(v, d, b, a);
    printf("%lld", v);
    return 0;
}

如果觉得写的不错就点个赞吧这个做法顶上去吧/kel