問題描述

項目BIT (Fenwick Tree)Segment Tree
支援的操作類型加法(或 XOR)類型,具有可逆性幾乎所有結合性操作(sum, min, max, gcd, 結構等)
常見用途區間和查詢、逆序對、樹狀結構統計區間和/最值查詢、最大子段和、Kth number、RMQ 等
單點修改$O(\log n)$$O(\log n)$
區間查詢$O(\log n)$$O(\log n)$
區間修改(加法)需要技巧(差分 or 雙 BIT)搭配 lazy propagation
區間設值(assign)不支援需 lazy 標記,支援設值操作
記憶體使用量較小 ($O(n)$)較大 ($O(4n)$)

基本實作細節

  • 以區間最大值為例
  • 採左閉右閉
  • 左半區間是 $[L, M]$;右半區間是 $[M+1, R]$
  • 第一顆節點索引是 1 (1-based)
  • 當前 index 為 $i$,則左子為 $2i$,右子為 $2i+1$,我們如此定義:
#define lc 2 * id
#define rc 2 * id + 1
  • 子樹任一節點區間必在根節點區間內
  • 當 $L==R$,該節點為葉節點
  • 若資料有 $N$ 筆,線段樹要開 $4N$

建構線段樹(build)

// mx 為該節點所代表的區間最大值
void build(int L, int R, int id) {
    // 葉節點最大值即自己本身
    if (L == R) {
        mx[id] = arr[L];
        return;
    }
    int M = (L + R) / 2;
    //建立左子樹
    build(L, M, lc);
    //建立右子樹
    build(M + 1, R, rc);
    //該節點區間最大值就是左子右子區間最大值,取 max
    mx[id] = max(mx[lc], mx[rc]);
}

查詢(query)

// [l,r]: 我們要查找的區間
// [L,R]: 該節點所表示的區間
// id: 節點索引
int query(int l, int r, int L, int R, int id) {
    // 若該節點區間在欲查找的區間內,我們可以直接拿此節點最大值回傳,之後會再跟其他區間內的最大值比較
    if (l <= L && R <= r) {
        return mx[id];
    }
    int M = (L + R) / 2;
    // 若欲查找的右區間比該節點 M 小,表示在該節點左子樹
    if (r <= M)
        return query(l, r, L, M, lc);
    // 若欲查找的左區間比該節點 M 大,表示在該節點右子樹
    else if (l > M)
        return query(l, r, M + 1, R, rc);
    // 否則 l <= M && r > M 表示該區間橫跨兩左右子
    else
        return max(query(l, r, L, M, lc), query(l, r, M + 1, R, rc));
}

Main

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n, k;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    build(1, n, 1); // 從節點 1 開始往下 build
    cin >> k;
    while (k--) {
        int a, b;
        cin >> a >> b;
        cout << query(a, b, 1, n, 1) << '\n'; // 從節點 1 開始往下 query
    }
}

懶人標記實作細節

延遲不必要的更新,等「真的需要用到」才下放,以節省時間與空間。

假設有一個陣列 $arr[1..8]$,要把區間 $[1, 8]$ 全部加上 $10$

  • 不使用 Lazy:會從根節點一路遞迴到葉子節點,每個都加上 $10$ → 時間複雜度是 $O(n)$
  • 使用 Lazy:只在根節點標記 $tag = 10$,等哪天有人問「第 $3 $個元素是多少?」再慢慢把這個 $+10$ 推下去給左右子節點 → $O(logn)$

這樣,如果沒人查詢中間的值,就完全不浪費時間處理底下的更新!

Node 結構

因為有懶標的關係,需要一個結構儲存 $tag$

struct Node {
    int sum = 0;
    int tag = 0;
};

vector<Node> seg;

加上懶標(addtag)

因為區間總和是把該區間全加上某值,我們是先把他加在該節點上,所以要 $tag * length$,之後有要用到再推下去

void addtag(int tag, int id, int L, int R) {
    seg[id].sum += tag * (R - L + 1);
    seg[id].tag += tag;
}

下推懶標(push)

把屬於該(左或右)子樹的總和往下推,所以是把加上 $tag * (left/right)child$ $length$

void push(int id, int L, int R) {
    int M = (L + R) / 2;
    addtag(seg[id].tag, lc, L, M);
    addtag(seg[id].tag, rc, M + 1, R);
    seg[id].tag = 0;
}

建構線段樹(build)- 區間總和版

原本是取 $max$,但現在要加起來。

void build(int L, int R, int id) {
    if (L == R) {
        seg[id].sum = arr[L];
        return;
    }
    int M = (L + R) / 2;
    build(L, M, lc);
    build(M + 1, R, rc);
    seg[id].sum = seg[lc].sum + seg[rc].sum;
}

區間修改(modify)

  • $r < L $ 或 $R < l$ 欲修改區間不在此節點區間內:不修改
  • $l <= L $ 且 $R <= r$ 欲修改區間完全在此節點內:加上懶標,要用再推就好
  • 欲修改區間與此節點區間部分重疊:先把 tag 往下推更新子樹,接著 modify 左右子
void modify(int l, int r, int v, int L, int R, int id) {
    if (r < L || R < l) return;
    if (l <= L && R <= r) {
        addtag(v, id, L, R);
        return;
    }
    push(id, L, R);
    int M = (L + R) / 2;
    modify(l, r, v, L, M, lc);
    modify(l, r, v, M + 1, R, rc);
    seg[id].sum = seg[lc].sum + seg[rc].sum;
}

查詢(query)- 區間總和版

  • $r < L $ 或 $ R < l$ 欲查找區間不在此節點區間內:加 0
  • $ l <= L $ 且 $R <= r $ 欲查找區間完全在此節點內:加上該節點總和
  • 欲查找區間與此節點區間部分重疊:先把 tag 往下推更新子樹,接著 query 左右子,回傳左右子樹總和相加的值
int query(int l, int r, int L, int R, int id) {
    if (r < L || R < l) return 0;
    if (l <= L && R <= r) return seg[id].sum;
    push(id, L, R);
    int M = (L + R) / 2;
    return query(l, r, L, M, lc) + query(l, r, M + 1, R, rc);
}

練習題

d539. 區間 MAX

連結:https://zerojudge.tw/ShowProblem?problemid=d539

  • 跟上述範例程式碼幾乎一模一樣,所以不附上程式碼
  • 可以看用 Sparse table 解的版本

d799. 區間求和

連結:https://zerojudge.tw/ShowProblem?problemid=d799

  • 跟區間總和版程式碼幾乎一模一樣

完整程式碼

#include <bits/stdc++.h>
using namespace std;
#define lc 2 * id
#define rc 2 * id + 1
const int MAXN = 5e5 + 5;

long long arr[MAXN];

struct Node {
    long long sum = 0;
    long long tag = 0; 
};
vector<Node> seg;

void addtag(long long tag, long long id, long long L, long long R) {
    seg[id].sum += tag * (R - L + 1);
    seg[id].tag += tag;
}

void push(long long id, long long L, long long R) {
    long long M = (L + R) / 2;
    addtag(seg[id].tag, lc, L, M);
    addtag(seg[id].tag, rc, M + 1, R);
    seg[id].tag = 0;
}

void build(long long L, long long R, long long id) {
    if (L == R) {
        seg[id].sum = arr[L];
        return;
    }
    long long M = (L + R) / 2;
    build(L, M, lc);
    build(M + 1, R, rc);
    seg[id].sum = seg[lc].sum + seg[rc].sum;
}

void modify(long long l, long long r, long long v, long long L, long long R, long long id) {
    if (r < L || R < l) return;
    if (l <= L && R <= r) {
        addtag(v, id, L, R);
        return;
    }
    push(id, L, R);
    long long M = (L + R) / 2;
    modify(l, r, v, L, M, lc);
    modify(l, r, v, M + 1, R, rc);
    seg[id].sum = seg[lc].sum + seg[rc].sum;
}

long long query(long long l, long long r, long long L, long long R, long long id) {
    if (r < L || R < l) return 0;
    if (l <= L && R <= r) return seg[id].sum;
    push(id, L, R);
    long long M = (L + R) / 2;
    return query(l, r, L, M, lc) + query(l, r, M + 1, R, rc);
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    long long n, k;
    cin >> n;
    for (long long i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    seg.resize(MAXN * 4);
    build(1, n, 1);
    cin >> k;
    while (k--) {
        int v;
        cin >> v;
        if (v == 1) {
            long long a, b, T;
            cin >> a >> b >> T;
            modify(a, b, T, 1, n, 1);
        } else {
            long long a, b;
            cin >> a >> b;
            cout << query(a, b, 1, n, 1) << '\n';
        }
    }
}