問題描述
項目 | BIT (Fenwick Tree) | Segment Tree |
---|---|---|
支援的操作類型 | 加法(或 XOR)類型,具有可逆性 | 幾乎所有結合性操作(sum, min, max, gcd, 結構等) |
常見用途 | 區間和查詢、逆序對、樹狀結構統計 | 區間和/最值查詢、最大子段和、Kth number、RMQ 等 |
單點修改 | $O(\log n)$ | $O(\log n)$ |
區間查詢 | $O(\log n)$ | $O(\log n)$ |
區間修改(加法) | 需要技巧(差分 or 雙 BIT) | 搭配 lazy propagation |
區間設值(assign) | 不支援 | 需 lazy 標記,支援設值操作 |
記憶體使用量 | 較小 ($O(n)$) | 較大 ($O(4n)$) |
基本實作細節
- 以區間最大值為例
- 採左閉右閉
- 左半區間是 $[L, M]$;右半區間是 $[M+1, R]$
- 第一顆節點索引是 1 (1-based)
- 當前 index 為 $i$,則左子為 $2i$,右子為 $2i+1$,我們如此定義:
#define lc 2 * id
#define rc 2 * id + 1
- 子樹任一節點區間必在根節點區間內
- 當 $L==R$,該節點為葉節點
- 若資料有 $N$ 筆,線段樹要開 $4N$
建構線段樹(build)
// mx 為該節點所代表的區間最大值
void build(int L, int R, int id) {
// 葉節點最大值即自己本身
if (L == R) {
mx[id] = arr[L];
return;
}
int M = (L + R) / 2;
//建立左子樹
build(L, M, lc);
//建立右子樹
build(M + 1, R, rc);
//該節點區間最大值就是左子右子區間最大值,取 max
mx[id] = max(mx[lc], mx[rc]);
}
查詢(query)
// [l,r]: 我們要查找的區間
// [L,R]: 該節點所表示的區間
// id: 節點索引
int query(int l, int r, int L, int R, int id) {
// 若該節點區間在欲查找的區間內,我們可以直接拿此節點最大值回傳,之後會再跟其他區間內的最大值比較
if (l <= L && R <= r) {
return mx[id];
}
int M = (L + R) / 2;
// 若欲查找的右區間比該節點 M 小,表示在該節點左子樹
if (r <= M)
return query(l, r, L, M, lc);
// 若欲查找的左區間比該節點 M 大,表示在該節點右子樹
else if (l > M)
return query(l, r, M + 1, R, rc);
// 否則 l <= M && r > M 表示該區間橫跨兩左右子
else
return max(query(l, r, L, M, lc), query(l, r, M + 1, R, rc));
}
Main
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n, k;
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> arr[i];
}
build(1, n, 1); // 從節點 1 開始往下 build
cin >> k;
while (k--) {
int a, b;
cin >> a >> b;
cout << query(a, b, 1, n, 1) << '\n'; // 從節點 1 開始往下 query
}
}
懶人標記實作細節
延遲不必要的更新,等「真的需要用到」才下放,以節省時間與空間。
假設有一個陣列 $arr[1..8]$,要把區間 $[1, 8]$ 全部加上 $10$
- 不使用 Lazy:會從根節點一路遞迴到葉子節點,每個都加上 $10$ → 時間複雜度是 $O(n)$
- 使用 Lazy:只在根節點標記 $tag = 10$,等哪天有人問「第 $3 $個元素是多少?」再慢慢把這個 $+10$ 推下去給左右子節點 → $O(logn)$
這樣,如果沒人查詢中間的值,就完全不浪費時間處理底下的更新!
Node 結構
因為有懶標的關係,需要一個結構儲存 $tag$
struct Node {
int sum = 0;
int tag = 0;
};
vector<Node> seg;
加上懶標(addtag)
因為區間總和是把該區間全加上某值,我們是先把他加在該節點上,所以要 $tag * length$,之後有要用到再推下去
void addtag(int tag, int id, int L, int R) {
seg[id].sum += tag * (R - L + 1);
seg[id].tag += tag;
}
下推懶標(push)
把屬於該(左或右)子樹的總和往下推,所以是把加上 $tag * (left/right)child$ $length$
void push(int id, int L, int R) {
int M = (L + R) / 2;
addtag(seg[id].tag, lc, L, M);
addtag(seg[id].tag, rc, M + 1, R);
seg[id].tag = 0;
}
建構線段樹(build)- 區間總和版
原本是取 $max$,但現在要加起來。
void build(int L, int R, int id) {
if (L == R) {
seg[id].sum = arr[L];
return;
}
int M = (L + R) / 2;
build(L, M, lc);
build(M + 1, R, rc);
seg[id].sum = seg[lc].sum + seg[rc].sum;
}
區間修改(modify)
- $r < L $ 或 $R < l$ 欲修改區間不在此節點區間內:不修改
- $l <= L $ 且 $R <= r$ 欲修改區間完全在此節點內:加上懶標,要用再推就好
- 欲修改區間與此節點區間部分重疊:先把 tag 往下推更新子樹,接著 modify 左右子
void modify(int l, int r, int v, int L, int R, int id) {
if (r < L || R < l) return;
if (l <= L && R <= r) {
addtag(v, id, L, R);
return;
}
push(id, L, R);
int M = (L + R) / 2;
modify(l, r, v, L, M, lc);
modify(l, r, v, M + 1, R, rc);
seg[id].sum = seg[lc].sum + seg[rc].sum;
}
查詢(query)- 區間總和版
- $r < L $ 或 $ R < l$ 欲查找區間不在此節點區間內:加 0
- $ l <= L $ 且 $R <= r $ 欲查找區間完全在此節點內:加上該節點總和
- 欲查找區間與此節點區間部分重疊:先把 tag 往下推更新子樹,接著 query 左右子,回傳左右子樹總和相加的值
int query(int l, int r, int L, int R, int id) {
if (r < L || R < l) return 0;
if (l <= L && R <= r) return seg[id].sum;
push(id, L, R);
int M = (L + R) / 2;
return query(l, r, L, M, lc) + query(l, r, M + 1, R, rc);
}
練習題
d539. 區間 MAX
- 跟上述範例程式碼幾乎一模一樣,所以不附上程式碼
- 可以看用 Sparse table 解的版本
d799. 區間求和
- 跟區間總和版程式碼幾乎一模一樣
完整程式碼
#include <bits/stdc++.h>
using namespace std;
#define lc 2 * id
#define rc 2 * id + 1
const int MAXN = 5e5 + 5;
long long arr[MAXN];
struct Node {
long long sum = 0;
long long tag = 0;
};
vector<Node> seg;
void addtag(long long tag, long long id, long long L, long long R) {
seg[id].sum += tag * (R - L + 1);
seg[id].tag += tag;
}
void push(long long id, long long L, long long R) {
long long M = (L + R) / 2;
addtag(seg[id].tag, lc, L, M);
addtag(seg[id].tag, rc, M + 1, R);
seg[id].tag = 0;
}
void build(long long L, long long R, long long id) {
if (L == R) {
seg[id].sum = arr[L];
return;
}
long long M = (L + R) / 2;
build(L, M, lc);
build(M + 1, R, rc);
seg[id].sum = seg[lc].sum + seg[rc].sum;
}
void modify(long long l, long long r, long long v, long long L, long long R, long long id) {
if (r < L || R < l) return;
if (l <= L && R <= r) {
addtag(v, id, L, R);
return;
}
push(id, L, R);
long long M = (L + R) / 2;
modify(l, r, v, L, M, lc);
modify(l, r, v, M + 1, R, rc);
seg[id].sum = seg[lc].sum + seg[rc].sum;
}
long long query(long long l, long long r, long long L, long long R, long long id) {
if (r < L || R < l) return 0;
if (l <= L && R <= r) return seg[id].sum;
push(id, L, R);
long long M = (L + R) / 2;
return query(l, r, L, M, lc) + query(l, r, M + 1, R, rc);
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
long long n, k;
cin >> n;
for (long long i = 1; i <= n; i++) {
cin >> arr[i];
}
seg.resize(MAXN * 4);
build(1, n, 1);
cin >> k;
while (k--) {
int v;
cin >> v;
if (v == 1) {
long long a, b, T;
cin >> a >> b >> T;
modify(a, b, T, 1, n, 1);
} else {
long long a, b;
cin >> a >> b;
cout << query(a, b, 1, n, 1) << '\n';
}
}
}