树状数组

树状数组

1. 绪论

树状数组本质上是一个 运用了 分块思想前缀和 数组,使得查询和修改的时间复杂度都是\(O(logN)\) 级别,但由于是一个前缀和数组,所以对于一些区间能做的事情还是十分有限,鉴于树状数组的实现简单,代码量少,对于问题是否使用树状数组还是线段树的情况需要自行辨别

2. 模板

#define lowbit(a) a&(-a)
int Trarr[100005];
void add(int x, int val){
    while(x <= MAX_LIMIT){
        Trarr[x] += val;
        x += lowbit(x);
    }
}
int getsum(int x){
    int ans = 0;
    while(x){
        ans += Trarr[x];
        x -= lowbit(x);
    }
    return ans;
}

3. 经典模型

  • PUIQ问题 (单点修改及区间求和) \(OptA\) . \(a_i += P\) \(OptB\) . \(\sum^i_{[L,R]} a_i\)

    考虑点 \(i\) 的值发生变化时会影响到树状数组的哪些部分,然后进行更新就可以了,一般是在点 \(i\) 加一个负值去消除原数的影响再加入新数,并更新其前继数组,区间求和就和前缀和一样,直接 \(getsum(R) – getsum(L-1)\)

    见 Luogu P3374

  • IUPQ问题 (区间修改和单点查询) \(Opt\) \(a_i += P,\quad i\in[L, R]\) \(Quiry\quad a_[i]\)

    考虑结合差分算法,区间修改是差分的算法优势,对于单点求值就是差分的前缀和,此时用树状数组来维护差分的前缀和数组就可以解决,对于区间修改直接\(Add(L, V)\quad Add(R+1, -V)\) 单点求和就相当于查前缀和数组 不再赘述

    见 Luogu P3368

  • 求逆序数问题 \(Count \quad i < j , a_i > a_j\) 如 $1,5,2,4,3 $ 中的\((5,2)\)是一对逆序对

    核心思想是记录 在树状数组中记录值为 \([1,a_i]\) 的个数注意不是单纯的将数据的值作为树状数组的下标,将树状数组当一个桶用,树状数组本质是一个前缀和数组,这样是行不通的,此时的\(Add(a_i, 1)\) 操作就是记录值在\([1,a_i]\) 的数又多一个, 计算逆序数时就一边遍历,一边将当前的数计入树状数组,然后询问比自己小的数有多少个 即询问值在\([1, a_i]\)的数有多少个,此时得到的是 \(i < j, ai < aj\) 的数,再用当前的 \(i\) 减去这个值就能得到这个位置的逆序数

    见 Luogu P1966

    统计个数的思想十分常用,有时也将树状数组数组全初始化为1或0,来查找第k个此时的值,结合下面的二分

  • 二分思想 就是在查找的过程中进行二分枚举

    比如你每次可以往容器中丢入一个编号为 \(i\) 的球,或者询问容器中编号第 \(K\) 大的球的编号,显然对于编号第 \(K\) 的球,前面只能有 \(K-1\) 个球比他大, 那就有\(getsum(N) – getsum(x) <= K-1\) ,同时由于前缀和的特殊性,树状数组中可能存在多个该情况的数字,此时取第一个数字,由于前缀和必然满足单调性,所以此时可以使用二分来查找满足这个条件的数字

  • 区间排序思想 有时为了维护区间或者比较多个区间内的信息,会固定一端进行排序

    如给定 \(N\) 个区间,计算第 \(i\) 个区间有多少个区间大于他,此时可以将给定的区间按右端点从大到小排序,右端点相同的则左端点小的排前面,用树状数组维护左端点的插值,即读入第 \(i\) 个区间时,就进行 \(Add(L_i, 1)\),表示\([1,L_i]\) 这一段中又插入了一个区间,查询时直接询问这一段有多少个区间就可以,因为我们进行了排序,所以保证右端点是递减的。

    还有如给定长度为 \(N\) 的一个序列,有 \(Q\) 个询问对于某段区间内的数字有多少种,此时也可进行区间排序,然后一边遍历一边维护,考虑对于第二次出现的数字,在树状数组中消除他在第一次出现的影响并在第二次出现的位置更新它,即我们此时在树状数组中维护的信息是 \([1, R]\)中有多少种数,并且保证了出现的数字不重复计算,且始终将其更新到最右端 。

    最后计算答案时只要询问 \(getsum(quiry[i].R) – getsum(quiry[i].L-1)\)

    见 Luogu P4113

  • 多维树状数组 有时一维的树状数组无法维护全部信息,此时选择多维记录状态

    如一个 \(N*M\) 的方格,每个格子初始时有一个整数权值,接下来有两种操作:

    改变一个格子里权值或者求一个子矩阵中特定权值出现的个数,此时我们选择开一个三维的树状数组,前两维表示位置,第三维表示数字(类逆序数模型),维护时就用两个for循环去控制位置,然后再对应数字位置上\(+val\) 如下面代码

    #define lowbit(a) a&(-a)
    void Tradd(int x, int y, int val, int color){
        for(int i = x; i <= N; i += lowbit(i)){
            for(int j = y; j <= M; j += lowbit(j)){
                Trarr[i][j][color] += val;
            }
        }
    }
    int Trqui(int x, int y, int color){
        int ans = 0;
        for(int i = x; i; i -= lowbit(i)){
            for(int j = y; j; j -= lowbit(j)){
                ans += Trarr[i][j][color];
            }
        }
        return ans;
    }
    

询问的时候就类似二维前缀和

\(Trqui(x2, y2, color) – Trqui(x1-1, y2, color) – Trqui(x2, y1-1, color) + Trqui(x1-1, y1-1, color)\)

不明白的读者可以尝试画一个图就能明白

这题见Luogu P4054

4. 大应用题

这里提供一道比较麻烦的例题 Luogu P3960

大概的一个思路每行单独维护,因为互不影响,最后一列比较特殊也需要单独维护,然后离线问题,将问题按照行从小到大排列,然后预处理每列的真实删除位置,比如对于\((1,3)\) 是删除第一行第三列数字,第二次删除时就不再是\((1,3)\) 了,考虑其先向左对齐,所以下一个删除的数字就是\((1,4)\),我们的预处理就是要找出这个\(4\)

这个预处理运用的是一个逆序数模型和二分模型,对于每一个删除的数字,要通过树状数组去预处理,确认这个数字真实的删除位置,维护这个信息需要将树状数组全部初始化成\(1\),查找需要删除的数字时,就二分查找值等于题目指定删除数字的列数,删除后\(Add(x, -1)\)

然后将问题按照问的顺序重新恢复,在进行模拟,对于真实位置小于列数的情况直接删除就可以,等于列数的情况就直接在最后一列上动手脚,大于列数的情况就额外开一个Vector数组去存每行每列额外的数字,这个最后一列删除的情况也是运用二分+树状数组去处理实际删除的是第几行 然后就AC了

下面是AC代码

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ctime>
#include <iostream>
#include <vector>
#define int long long
#define lowbit(a) a&(-a)
using namespace std;
const int maxn = 1e5+5;
int N, M, Q, Tr1[600005], Tr2[600005], last[600005];
vector<int> Extra[300005];
struct info{
    int x, y, pos;
    bool vis;
};
struct info quiry[300005];
void Tr1add(int x, int val){
    while(x <= 600005){
        Tr1[x] += val;
        x += lowbit(x);
    }
}
int Tr1qui(int x){
    int ans = 0;
    while(x){
        ans += Tr1[x];
        x -= lowbit(x);
    }
    return ans;
}
void Tr2add(int x, int val){
    while(x <= 600005){
        Tr2[x] += val;
        x += lowbit(x);
    }
}
int Tr2qui(int x){
    int ans = 0;
    while(x){
        ans += Tr2[x];
        x -= lowbit(x);
    }
    return ans;
}
bool cmp1(info a, info b){
    if(a.x == b.x) return a.pos < b.pos;
    else return a.x < b.x;
}
bool cmp2(info a, info b){
    return a.pos < b.pos;
}
signed main(){
    clock_t c1 = clock();
#ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    scanf("%lld%lld%lld", &N, &M, &Q);
    for(int i = 1; i <= Q; i++){
        scanf("%lld%lld", &quiry[i].x, &quiry[i].y);
        quiry[i].pos = i;
    }
    for(int i = 1; i <= 600005; i++) Tr1add(i, 1), Tr2add(i, 1);
    for(int i = 1; i <= N; i++){
        last[i] = i*M;
    }
    sort(quiry+1, quiry+1+Q, cmp1);
    int Now = 1;
    for(int i = 1; i <= Q; i++){
        int Len = M;
        while(Now < Q && quiry[i].x == quiry[Now+1].x) Now++;
        vector<int> mid;
        for(int j = i; j <= Now; j++){
            if(quiry[j].y == M){ 
                quiry[j].vis = true;
                continue;
            }
            int L = 1, R = Len;
            while(R > L){
                int MID = (R+L) / 2;
                if(Tr1qui(MID) >= quiry[j].y) R = MID;
                else L = MID+1;
            }
            mid.push_back(L); Tr1add(L, -1); Len++; quiry[j].y = L;
        }
        for(int j = 0; j < mid.size(); j++) Tr1add(mid[j], 1);
        i = Now;
    }
    sort(quiry+1, quiry+1+Q, cmp2);
    int Len = N;
    for(int i = 1; i <= Q; i++){
        int ans;
        if(!quiry[i].vis){
            if(quiry[i].y < M) ans = (quiry[i].x-1)*M+quiry[i].y, printf("%lld\n", ans);
            else{
                quiry[i].y -= M; ans = Extra[quiry[i].x][quiry[i].y];
                printf("%lld\n", ans);
            }
        }
        int L = 1, R = Len;
        while(R > L){
            int MID = (R+L) / 2;
            if(Tr2qui(MID) >= quiry[i].x) R = MID;
            else L = MID+1;
        }
        Tr2add(L, -1);
        if(quiry[i].vis) ans = last[L], printf("%lld\n", ans);
        else Extra[quiry[i].x].push_back(last[L]);
        last[++Len] = ans;
    }
end:
    cerr << "Time Used: " << clock() - c1 << "ms" << endl;
    return 0;
}