樹狀數組

樹狀數組

1. 緒論

樹狀數組本質上是一個 運用了 分塊思想前綴和 數組,使得查詢和修改的時間複雜度都是\(O(logN)\) 級別,但由於是一個前綴和數組,所以對於一些區間能做的事情還是十分有限,鑒於樹狀數組的實現簡單,程式碼量少,對於問題是否使用樹狀數組還是線段樹的情況需要自行辨別

2. 模板

#define lowbit(a) a&(-a)
int Trarr[100005];
void add(int x, int val){
    while(x <= MAX_LIMIT){
        Trarr[x] += val;
        x += lowbit(x);
    }
}
int getsum(int x){
    int ans = 0;
    while(x){
        ans += Trarr[x];
        x -= lowbit(x);
    }
    return ans;
}

3. 經典模型

  • PUIQ問題 (單點修改及區間求和) \(OptA\) . \(a_i += P\) \(OptB\) . \(\sum^i_{[L,R]} a_i\)

    考慮點 \(i\) 的值發生變化時會影響到樹狀數組的哪些部分,然後進行更新就可以了,一般是在點 \(i\) 加一個負值去消除原數的影響再加入新數,並更新其前繼數組,區間求和就和前綴和一樣,直接 \(getsum(R) – getsum(L-1)\)

    見 Luogu P3374

  • IUPQ問題 (區間修改和單點查詢) \(Opt\) \(a_i += P,\quad i\in[L, R]\) \(Quiry\quad a_[i]\)

    考慮結合差分演算法,區間修改是差分的演算法優勢,對於單點求值就是差分的前綴和,此時用樹狀數組來維護差分的前綴和數組就可以解決,對於區間修改直接\(Add(L, V)\quad Add(R+1, -V)\) 單點求和就相當於查前綴和數組 不再贅述

    見 Luogu P3368

  • 求逆序數問題 \(Count \quad i < j , a_i > a_j\) 如 $1,5,2,4,3 $ 中的\((5,2)\)是一對逆序對

    核心思想是記錄 在樹狀數組中記錄值為 \([1,a_i]\) 的個數注意不是單純的將數據的值作為樹狀數組的下標,將樹狀數組當一個桶用,樹狀數組本質是一個前綴和數組,這樣是行不通的,此時的\(Add(a_i, 1)\) 操作就是記錄值在\([1,a_i]\) 的數又多一個, 計算逆序數時就一邊遍歷,一邊將當前的數計入樹狀數組,然後詢問比自己小的數有多少個 即詢問值在\([1, a_i]\)的數有多少個,此時得到的是 \(i < j, ai < aj\) 的數,再用當前的 \(i\) 減去這個值就能得到這個位置的逆序數

    見 Luogu P1966

    統計個數的思想十分常用,有時也將樹狀數組數組全初始化為1或0,來查找第k個此時的值,結合下面的二分

  • 二分思想 就是在查找的過程中進行二分枚舉

    比如你每次可以往容器中丟入一個編號為 \(i\) 的球,或者詢問容器中編號第 \(K\) 大的球的編號,顯然對於編號第 \(K\) 的球,前面只能有 \(K-1\) 個球比他大, 那就有\(getsum(N) – getsum(x) <= K-1\) ,同時由於前綴和的特殊性,樹狀數組中可能存在多個該情況的數字,此時取第一個數字,由於前綴和必然滿足單調性,所以此時可以使用二分來查找滿足這個條件的數字

  • 區間排序思想 有時為了維護區間或者比較多個區間內的資訊,會固定一端進行排序

    如給定 \(N\) 個區間,計算第 \(i\) 個區間有多少個區間大於他,此時可以將給定的區間按右端點從大到小排序,右端點相同的則左端點小的排前面,用樹狀數組維護左端點的插值,即讀入第 \(i\) 個區間時,就進行 \(Add(L_i, 1)\),表示\([1,L_i]\) 這一段中又插入了一個區間,查詢時直接詢問這一段有多少個區間就可以,因為我們進行了排序,所以保證右端點是遞減的。

    還有如給定長度為 \(N\) 的一個序列,有 \(Q\) 個詢問對於某段區間內的數字有多少種,此時也可進行區間排序,然後一邊遍歷一邊維護,考慮對於第二次出現的數字,在樹狀數組中消除他在第一次出現的影響並在第二次出現的位置更新它,即我們此時在樹狀數組中維護的資訊是 \([1, R]\)中有多少種數,並且保證了出現的數字不重複計算,且始終將其更新到最右端 。

    最後計算答案時只要詢問 \(getsum(quiry[i].R) – getsum(quiry[i].L-1)\)

    見 Luogu P4113

  • 多維樹狀數組 有時一維的樹狀數組無法維護全部資訊,此時選擇多維記錄狀態

    如一個 \(N*M\) 的方格,每個格子初始時有一個整數權值,接下來有兩種操作:

    改變一個格子里權值或者求一個子矩陣中特定權值出現的個數,此時我們選擇開一個三維的樹狀數組,前兩維表示位置,第三維表示數字(類逆序數模型),維護時就用兩個for循環去控制位置,然後再對應數字位置上\(+val\) 如下面程式碼

    #define lowbit(a) a&(-a)
    void Tradd(int x, int y, int val, int color){
        for(int i = x; i <= N; i += lowbit(i)){
            for(int j = y; j <= M; j += lowbit(j)){
                Trarr[i][j][color] += val;
            }
        }
    }
    int Trqui(int x, int y, int color){
        int ans = 0;
        for(int i = x; i; i -= lowbit(i)){
            for(int j = y; j; j -= lowbit(j)){
                ans += Trarr[i][j][color];
            }
        }
        return ans;
    }
    

詢問的時候就類似二維前綴和

\(Trqui(x2, y2, color) – Trqui(x1-1, y2, color) – Trqui(x2, y1-1, color) + Trqui(x1-1, y1-1, color)\)

不明白的讀者可以嘗試畫一個圖就能明白

這題見Luogu P4054

4. 大應用題

這裡提供一道比較麻煩的例題 Luogu P3960

大概的一個思路每行單獨維護,因為互不影響,最後一列比較特殊也需要單獨維護,然後離線問題,將問題按照行從小到大排列,然後預處理每列的真實刪除位置,比如對於\((1,3)\) 是刪除第一行第三列數字,第二次刪除時就不再是\((1,3)\) 了,考慮其先向左對齊,所以下一個刪除的數字就是\((1,4)\),我們的預處理就是要找出這個\(4\)

這個預處理運用的是一個逆序數模型和二分模型,對於每一個刪除的數字,要通過樹狀數組去預處理,確認這個數字真實的刪除位置,維護這個資訊需要將樹狀數組全部初始化成\(1\),查找需要刪除的數字時,就二分查找值等於題目指定刪除數字的列數,刪除後\(Add(x, -1)\)

然後將問題按照問的順序重新恢復,在進行模擬,對於真實位置小於列數的情況直接刪除就可以,等於列數的情況就直接在最後一列上動手腳,大於列數的情況就額外開一個Vector數組去存每行每列額外的數字,這個最後一列刪除的情況也是運用二分+樹狀數組去處理實際刪除的是第幾行 然後就AC了

下面是AC程式碼

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ctime>
#include <iostream>
#include <vector>
#define int long long
#define lowbit(a) a&(-a)
using namespace std;
const int maxn = 1e5+5;
int N, M, Q, Tr1[600005], Tr2[600005], last[600005];
vector<int> Extra[300005];
struct info{
    int x, y, pos;
    bool vis;
};
struct info quiry[300005];
void Tr1add(int x, int val){
    while(x <= 600005){
        Tr1[x] += val;
        x += lowbit(x);
    }
}
int Tr1qui(int x){
    int ans = 0;
    while(x){
        ans += Tr1[x];
        x -= lowbit(x);
    }
    return ans;
}
void Tr2add(int x, int val){
    while(x <= 600005){
        Tr2[x] += val;
        x += lowbit(x);
    }
}
int Tr2qui(int x){
    int ans = 0;
    while(x){
        ans += Tr2[x];
        x -= lowbit(x);
    }
    return ans;
}
bool cmp1(info a, info b){
    if(a.x == b.x) return a.pos < b.pos;
    else return a.x < b.x;
}
bool cmp2(info a, info b){
    return a.pos < b.pos;
}
signed main(){
    clock_t c1 = clock();
#ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    scanf("%lld%lld%lld", &N, &M, &Q);
    for(int i = 1; i <= Q; i++){
        scanf("%lld%lld", &quiry[i].x, &quiry[i].y);
        quiry[i].pos = i;
    }
    for(int i = 1; i <= 600005; i++) Tr1add(i, 1), Tr2add(i, 1);
    for(int i = 1; i <= N; i++){
        last[i] = i*M;
    }
    sort(quiry+1, quiry+1+Q, cmp1);
    int Now = 1;
    for(int i = 1; i <= Q; i++){
        int Len = M;
        while(Now < Q && quiry[i].x == quiry[Now+1].x) Now++;
        vector<int> mid;
        for(int j = i; j <= Now; j++){
            if(quiry[j].y == M){ 
                quiry[j].vis = true;
                continue;
            }
            int L = 1, R = Len;
            while(R > L){
                int MID = (R+L) / 2;
                if(Tr1qui(MID) >= quiry[j].y) R = MID;
                else L = MID+1;
            }
            mid.push_back(L); Tr1add(L, -1); Len++; quiry[j].y = L;
        }
        for(int j = 0; j < mid.size(); j++) Tr1add(mid[j], 1);
        i = Now;
    }
    sort(quiry+1, quiry+1+Q, cmp2);
    int Len = N;
    for(int i = 1; i <= Q; i++){
        int ans;
        if(!quiry[i].vis){
            if(quiry[i].y < M) ans = (quiry[i].x-1)*M+quiry[i].y, printf("%lld\n", ans);
            else{
                quiry[i].y -= M; ans = Extra[quiry[i].x][quiry[i].y];
                printf("%lld\n", ans);
            }
        }
        int L = 1, R = Len;
        while(R > L){
            int MID = (R+L) / 2;
            if(Tr2qui(MID) >= quiry[i].x) R = MID;
            else L = MID+1;
        }
        Tr2add(L, -1);
        if(quiry[i].vis) ans = last[L], printf("%lld\n", ans);
        else Extra[quiry[i].x].push_back(last[L]);
        last[++Len] = ans;
    }
end:
    cerr << "Time Used: " << clock() - c1 << "ms" << endl;
    return 0;
}