問題描述

計算前綴和後陣列後,一旦更改某個值,則包含該值的前綴和都必須跟著改動。假設我們改動第 $2$ 格,則以後的前綴和都要改動,時間複雜度為 $O(n)$。因此我們試圖讓包含該值的區間可以減少,這樣就可以更動少一點資料。

Least Significant Bit(LSB)

在這裡的定義是最後面的 $1$ 所代表的數字

LSB 可以用以下方式計算:

  • $x$ & $(-x)$

例如:

  • $6$ 的二進位是 $110$,LSB = $2$
  • $12$ 的二進位是 $1100$,LSB = $4$
  • $24$ 的二進位是 $11000$,LSB = $8$

LSB 在 BIT 中扮演重要角色:

  1. 用來判斷每個節點負責的區間長度
  2. 用來找父節點和子節點

Binary Indexed Tree(BIT) 原理

BIT 利用二進位的特性,將每個位置負責不同長度的區間和。

這邊的 Binary 不是 Binary Tree,而是二進制的意思

例如 index = $6$ (二進位 $110$),則負責區間 $[5,6]$ 的和

  • 因為 $6$ 的 LSB 是 $2$,所以往前負責 $2$ 個數字

圖解 BIT 結構

BIT 結構 圖片來源:Ruby Ku 的投影片

每個節點負責的區間長度就是它的 LSB,範圍為 $[i-LSB(i)+1, i]$。 例如:

  • 節點 $8$ 的 LSB = $8$,負責 $[1,8]$
  • 節點 $12$ 的 LSB = $4$,負責 $[9,12]$
  • 節點 $14$ 的 LSB = $2$,負責 $[13,14]$
  • 節點 $15$ 的 LSB = $1$,負責 $[15,15]$

Indexed tree 圖片來源:WIWIHO 的競程筆記

觀察此圖,我們可以發現:

  1. 節點 $0$ 是根節點
  2. 節點 $i$ 的父節點是 $i - LSB(i)$
  3. 節點 $i$ 的右兄弟節點是 $i + LSB(i)$。
  4. 若已經沒有右兄弟節點的話,$i + LSB(i)$ 會是 i 節點的父節點的右兄弟節點。可觀察第 $7$ 節點和第 $8$ 節點。
  5. 節點 $i$ 的深度與 $i$ 是 $1$ 的位元數相同
  6. 節點 $i$ 的父節點是 $p$,則節點 $i$ 的區間是 $[p+1, i]$,因為 $p = i - LSB(i)$
  7. 節點 $i$ 的區間包含它所有左兄弟節點的區間,因為它們的父節點都一樣,所以它們區間的起點都一樣,但 $i$ 的結束點較後面
  8. 節點 $i$ 的區間包含所有左兄弟節點 $j$ 的子孫節點的區間,因為 $j$ 子孫節點區間必在 $j$ 和 $i-1$ 之間,而 $i$ 的區間肯定包含這段。

實作細節

  1. update($x$, $val$):
  • 更新 $x$ 位置的值,會影響到所有包含 $x$ 的區間
  • 每次找右兄弟節點:$x$ += lowbit($x$)(第$3$點)
  • 因為 x 的父節點的右兄弟也包含 x(第 $8$ 點),所以我要得到父節點的右節點,也就是最右兄弟節點為 $t$,其父節點的右兄弟就是 $t + LSB(t)$(第 $7$ 點)
void update(int x, int d){ // 把 x 節點的值改成 d
    while(x <= N){
        b[x] += d;
        x += x & (-x);
    }
}

可以發現 $LSB(x)$ 會不斷增加,所以複雜度是 $O(log n)$

  1. query($x$):
  • 查詢以 $x$ 為結尾的的區間和,就要找到幾段不會重疊聯集為所求前綴的區間
  • 區間以 $X$ 為結尾的節點是 ,而節點 $x$ 的區間剛好緊接在它的父節點之後,它的父節點是 $x - LSB(x)$,所以只要找到 $x$ 和它所有祖先節點,這些區間聯集起來就是我們想要的前綴
  • 每次往上找父節點:$x$ -= lowbit($x$)
int query(int x){
    int ret = 0;
    while (x){
        ret += b[x];
        x -= x & (-x);
    }
    return ret;
}

同樣地, $LSB(x)$ 會不斷增加,因此複雜度是 $O(log n)$

  1. 建構 tree 時,如果有初始值的話,就把每一個元素分別 update 就好了,複雜度是 O(nlogn)

程式碼實作

#include <iostream>
#include <vector>
using namespace std;

const int N = 100005;  // 根據需求調整大小
int bit[N];            // Binary Indexed Tree,1-based index

// 單點加值:將第 i 個位置加上 val
void update(int i, int val) {
    while (i < N) {
        bit[i] += val;
        i += i & -i;
    }
}

// 前綴和查詢:回傳前 i 個元素的總和
int query(int i) {
    int res = 0;
    while (i > 0) {
        res += bit[i];
        i -= i & -i;
    }
    return res;
}

// 查詢區間 [l, r] 的總和
int range_query(int l, int r) {
    return query(r) - query(l - 1);
}

// 建表:從原始陣列 a[1..n] 建立 BIT
void build(const vector<int>& a, int n) {
    for (int i = 1; i <= n; ++i) {
        update(i, a[i]);  // a 是從 1 開始的
    }
}

int main() {
    int n = 5;
    vector<int> a(n + 1);  // a[1..n]
    a[1] = 3; a[2] = 2; a[3] = -1; a[4] = 6; a[5] = 5;

    build(a, n);  // 建立 BIT

    cout << query(5) << "\n";          // 輸出 a[1] + ... + a[5]
    cout << range_query(2, 4) << "\n"; // 輸出 a[2] + a[3] + a[4]

    update(3, 4);                      // a[3] += 4
    cout << query(5) << "\n";          // 查詢更新後的總和

    return 0;
}

離散化 Discretization

處理元素值範圍太大或值不是連續整數(可能是浮點數或無序整數) 時常用的技巧。配合 BIT 就可以在這些情況下仍有效率地進行區間查詢。

假設要用 BIT 統計一個陣列中,小於某個值的元素有幾個。如果這些值範圍很大,例如 $10^9$,你就不能直接開一個長度 $10^9$ 的陣列來當 BIT,那會爆記憶體,所以我們要把這些值 壓縮成連續整數編號(例如 1、2、3…),來當作 BIT 的索引。

原始陣列:

vector<int> a = {100, 5000, 3, 100, 3};
  1. 提取所有值並排序
vector<int> vals = a;
sort(vals.begin(), vals.end());
vals.erase(unique(vals.begin(), vals.end()), vals.end());  // 移除重複值
  1. 建立「值 → 編號」的對應表
unordered_map<int, int> mp;
for (int i = 0; i < vals.size(); ++i) {
    mp[vals[i]] = i + 1;  // 使用 1-based index
}
  1. 用對應值更新 BIT
for (int x : a) {
    int idx = mp[x];  // 取得離散化後的 index
    update(idx, 1);   // 假設你要統計出現次數
}
  1. 查詢小於某個值的個數(假設查小於 5000 的)
int idx = mp[5000];     // 5000 的 index
int res = query(idx - 1);  // 查詢比它小的值的總出現次數

練習題

Zerojudge - d794. 世界排名

連結:https://zerojudge.tw/ShowProblem?problemid=d794

本題 AC code

  • 3.5s, 22.4MB
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 100005;
int bit[N]; // Binary Indexed Tree,1-based

// 更新:將第 i 位加上 val
int update(int i, int val) {
    while (i < N) {
        bit[i] += val;
        i += i & -i;
    }
}

// 查詢:回傳前 i 位總和
int query(int i) {
    int res = 0;
    while (i > 0) {
        res += bit[i];
        i -= i & -i;
    }
    return res;
}

int main() {
    int n;
    while (cin >> n) {
        for (int i = 1; i <= N; i++) {
            bit[i] = 0; // clear BIT
        }
        vector<ll> a(n);
        for (auto& x : a) {
            cin >> x;
        }
        // Step 1: 離散化
        vector<ll> vals = a;
        sort(vals.begin(), vals.end());
        vals.erase(unique(vals.begin(), vals.end()), vals.end());

        unordered_map<ll, ll> mp; // 值 -> 離散後的 index(1-based)
        for (int i = 0; i < vals.size(); ++i) {
            mp[vals[i]] = i + 1;
        }

        // Step 2: 統計每個位置左邊比它小的數
        for (int i = 0; i < n; ++i) {
            int idx = mp[a[i]];
            cout << i + 1 - query(idx - 1) << "\n"; // 查詢比當前值小的數有幾個
            update(idx, 1);                         // 插入當前值進 BIT 中
        }
    }
    return 0;
}

CSES - Nested Ranges Count

連結:https://cses.fi/problemset/task/2169

這題要求我們計算每個區間被多少其他區間包含,以及包含多少其他區間。

對於每個區間,我們需要輸出兩個數字

  • 收集所有 $y$ 座標並離散化,方便作為 BIT 的 index

  • 按照 $x$ 座標排序,如果 $x$ 座標相同,則按照 $y$ 座標降序排序

  • 計算包含的區間數(當前區間為 $[a,b]$ 之前的區間為 $[c,d]$,$[a,b]$ 包含 $[c,d]$ 表示 $a<=c$ 且 $d<=b$)

    • 從右至左處理,確保 $a <= c$
    • 使用 BIT 記錄目前出現過的右界數量
    • query($b$) 查詢目前已有的右界小於等於此區間右界($d<=b$)的總數量
  • 計算被包含的區間數(當前區間為 $[a,b]$ 之前的區間為 $[c,d]$,$[a,b]$ 被 $[c,d]$ 包含表示 $a>=c$ 且 $b<=d$)

    • 從左至右處理,確保 $a >= c$
    • 透過「差分更新」的方法處理 BIT(對整體 $+1$,對 $b+1$ 開始 $-1$) 假設前兩個是 $[2,8]$, $[3,9]$,現在是 $[3,5]$

    離散化後 ${5,8,9}$ -> ${1,2,3}$

    這樣的話 $[2,8]$ 時會將 BIT[$2$] 以前的 BIT 都 $+1$

    BIT[$1$]BIT[$2$]BIT[$3$]
    $1$$1$$0$

    $[3,9]$ 時會將 BIT[$3$] 之前的 BIT 都 $+1$

    BIT[$1$]BIT[$2$]BIT[$3$]
    $2$$2$$1$

    之後查 $[3,5]$ 時因為以確保左界會比之前的大,我們這時候查 query($1$),就可以得到之前有多少區間右界 $>=5$

本題 AC code

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define endl '\n'
#define F first
#define S second

map<int, int> mp;
const int N = 2e5 + 1;
int T[N + 1] = {0};
void update(int n, int x, int val) {
    for (; x <= n; x += x & -x) T[x] += val;
}
int query(int x) {
    int s = 0;
    for (; x > 0; x -= x & -x) s += T[x];
    return s;
}

bool comp(pair<pair<int, int>, int> a, pair<pair<int, int>, int> b) {
    if (a.F.F == b.F.F) return a.F.S > b.F.S;
    return a.F.F < b.F.F;
}

signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin >> n;
    set<int> b;
    vector<pair<pair<int, int>, int>> v(n);
    for (int i = 0; i < n; i++) {
        int x, y;
        cin >> x >> y;
        v[i].S = i;
        v[i].F = {x, y};
        b.insert(y);
    }
    int cnt = 0;
    for (auto i : b) {
        mp[i] = ++cnt;
    }
    sort(v.begin(), v.end(), comp);
    update(cnt, mp[v[n - 1].F.S], 1);
    int ans[n] = {0};
    for (int i = n - 2; i >= 0; i--) {
        ans[v[i].S] += query(mp[v[i].F.S]);
        update(cnt, mp[v[i].F.S], 1);
    }

    for (int i = 0; i < n; i++) cout << ans[i] << ' ';

    memset(ans, 0, sizeof ans);
    memset(T, 0, sizeof T);
    update(cnt, 1, 1);
    update(cnt, mp[v[0].F.S] + 1, -1);
    for (int i = 1; i < n; i++) {
        ans[v[i].S] += query(mp[v[i].F.S]);
        update(cnt, 1, 1);
        update(cnt, mp[v[i].F.S] + 1, -1);
    }
    cout << endl;
    for (int i = 0; i < n; i++) cout << ans[i] << ' ';
}