理解線段樹這一篇文章就夠啦!

2023-03-01 18:01:04

線段樹

TODO:

前言

本文中,若無特殊說明,數列下標均從 \(1\) 開始

由於本人實力有限,線段樹更高階的拓展暫不做考慮

引入

什麼是線段樹

線段樹\(Segment\ Tree\))是一種二元搜尋樹,它將一個區間劃分成一些單元區間,每個單元區間對應線段樹中的一個葉子節點,由於每一個節點都表示一個區間(或者說是線段),所以也被認為是一顆區間樹。

用途

線段樹常用於動態維護區間資訊

例題

P3374 【模板】樹狀陣列 1 - 洛谷

題目簡述:對數列進行單點修改以及區間求和

常規解法

單點修改的時間複雜度為 \(O(1)\)

區間求和的時間複雜度為 \(O(n)\)

\(m\) 次操作,則總時間複雜度為 \(O(n\times m)\)

點選檢視程式碼
import java.io.*;

public class Main {
    static StreamTokenizer in = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));

    static int get() throws IOException {
        in.nextToken();
        return (int) in.nval;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = get(), m = get();
        int[] a = new int[n];
        for (int i = 0; i < n; ++i) a[i] = get();
        while (m-- != 0) {
            int command = get(), x = get(), y = get();
            if (command == 1) {
                a[x - 1] += y;
            } else {
                int sum = 0;
                for (int i = x - 1; i < y; i++) sum += a[i];
                out.println(sum);
            }
        }
        out.close();
    }
}

字首和解法

區間求和通過字首和優化,但單點修改的時候需要修改字首和陣列

單點修改的時間複雜度為 \(O(n)\)

區間求和的時間複雜度為 \(O(1)\)

\(m\) 次操作,則總時間複雜度為 \(O(n\times m)\)

點選檢視程式碼
import java.io.*;

public class Main {
    static StreamTokenizer in = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));

    static int get() throws IOException {
        in.nextToken();
        return (int) in.nval;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = get(), m = get();
        int[] sum = new int[n + 1];
        for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + get();
        while (m-- != 0) {
            int command = get(), x = get(), y = get();
            if (command == 1) {
                for (int i = x; i <= n; ++i) sum[i] += y;
            } else {
                System.out.println(sum[y] - sum[x - 1]);
            }
        }
        out.close();
    }
}

線段樹解法

線段樹的思想

線段樹是一種基於分治思想的一種資料結構,它通過不斷將區間拆分合併來實現區間改查

線段樹的形態結構

我們規定:若當前節點的標號為 \(x\),則其左兒子標號為 \(2\times x\),右兒子標號為 \(2\times x+1\),葉子節點的管轄區間長度為 \(1\)

一顆管理陣列長度為 \(7\) 的線段樹基本結構如下,其中藍色圓中資料代表節點標號,綠色矩形內資料代表該節點的管轄區間。

線段樹的儲存

對於二元樹的儲存,通常選擇使用指標儲存,但在演演算法競賽中,常選擇堆式儲存(靜態陣列儲存)。

選擇堆式儲存時,我們需要確定陣列空間大小。

一顆管理陣列長度為 \(n\) 的線段樹的節點個數為 \(2\times n -1\)

證明如下:

設一顆線段樹的度數為 \(0\) 的節點個數為 \(N_0\),度數為 \(1\) 的節點個數為 \(N_1\),度數為 \(2\) 的節點個數為 \(N_2\)

由線段樹的定義,葉子節點的管轄區間長度為 \(1\),則葉子節點的個數為 \(n\),即 \(N_0=n\)

每個節點代表一個區間,如果一個區間能劃分,則一定劃分為 \(2\) 個區間,因此 \(N_1=0\)

二元樹的性質:\(N_2=N_0-1\)

因此,一顆管理陣列長度為 \(n\) 的線段樹的節點個數為 \(N=N_0+N_1+N_2=2\times n -1\)

那是否靜態陣列空間就只需要 \(2\times n-1\) 呢?

線段樹的形態結構中的圖表示並非如此。

因為有些節點是空的,所以最後一個節點標號一定與滿二元樹相同。

前置結論:對於高度為 \(h\) 的滿二元樹,最後一層有 \(2^{h-1}\) 個節點,總共有 \(2^h-1\) 個節點,則除最後一層外的節點總數有 \(2^{h}-1-2^{h-1}=2^{h-1}-1\),與最後一層節點個數對比,得:$ 除最後一層外的節點總數 = 最後一層的節點個數 -1 $

線段樹所需要的節點數量,分兩種情況來討論:

  • 如果 \(n\) 恰好是 \(2\)\(k\) 次冪,由於線段樹最後一層的葉子節點儲存的是陣列元素本身,最後一層的節點數就是 \(n\),則前面所有層的節點數為 \(n-1\),那麼總節點數為 \(2\times n -1\)

  • 如果 \(n\) 不是 \(2\)\(k\) 次冪,即 \(n=2^k+x\) 其中 \(x>0\),則需要新開闢一層來儲存,等同於 \(2^{k+1}\) 的情況,則總結點個數為 \(4n-4x-1\),最大不超過 \(4n-5\)

又由於我們讓資料從下標為 \(1\) 開始儲存,得出如下結論:

  • \(n\)\(2\) 的正整數冪時,所需空間大小為 \(2\times n\)
  • \(n\) 不是 \(2\) 的正整數冪時,所需空間大小為 \(4\times n-4\)

為了方便,我們通常選擇開闢 \(4\times n\) 的空間

建樹

以下 \(tree[i]\) 代表 \(i\) 號節點所儲存的資料,\(data[i]\) 代表原陣列資料

每個節點 \(p\) 的左右子節點的編號分別為 \(2p\)\(2p+1\)

要求得某一個節點的值,需要得到兩個子節點的值,再將其合併,採用遞迴的形式建樹,其中合併操作單獨記為一個函數(修改操作時用)

遞迴的終止條件為達到葉子節點,即節點管轄區間長度為 \(1\),不能再劃分了,此時 \(l=r\)

所需函數引數如下:

  1. 當前節點的編號,即 \(tree\) 陣列中的索引 \(o\)
  2. 該節點所管轄區間的左邊界 \(l\)
  3. 該節點所管轄區間的右邊界 \(r\)
// 合併x和y兩個節點的區間值,並賦給o節點
public void pushUp(int o, int x, int y) {
    tree[o] = tree[x] + tree[y];
}

/**
 * @param o     當前節點編號
 * @param l     當前節點管轄區間的左邊界
 * @param r     當前節點管轄區間的右邊界
 * @param data  原陣列資料
 */
public void build(int o, int l, int r, int[] data) {
    // 到達葉子節點(管轄區間長度為1)
    if (l == r) {
        tree[o] = data[l];
        return;
    }
    // mid為中間值,用於劃分割區間
    // x 為左兒子編號
    // y 為右兒子編號
    int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 構建左子樹,區間為[l,mid]
    build(x, l, mid, data);
    // 構建右子樹,區間為[mid+1,r]
    build(y, mid + 1, r, data);
    // 合併兩個子區間的資料
    pushUp(o, x, y);
}

單點修改

修改元素時,需要先找到待修改的最底層的資料(葉子節點),修改後再逐步上傳資料

單點修改的基本步驟如下:

  1. 若待修改元素位於 \([l,mid]\) 區間,則遞迴修改左子樹部分

    若待修改元素位於 \([mid+1,r]\) 區間,則遞迴修改右子樹部分

  2. 合併兩個子區間的資料

所需函數引數如下:

  1. 待修改元素位置 \(index\)
  2. 修改後元素(或增量)的資料 \(val\)
  3. 當前節點的編號,即 \(tree\) 陣列中的索引 \(o\)
  4. 該節點所管轄區間的左邊界 \(l\)
  5. 該節點所管轄區間的右邊界 \(r\)
/**
 * @param index 待修改元素位置
 * @param val   修改後元素(或增量)的資料
 * @param o     當前節點編號
 * @param l     當前節點管轄區間的左邊界
 * @param r     當前節點管轄區間的右邊界
 */
public void updateOne(final int index, final int val, int o, int l, int r) {
    // 到達葉子節點(管轄區間長度為1)
    if (l == r) {
        tree[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 修改元素在左邊區間
    if (index <= mid) updateOne(index, val, x, l, mid);
    // 修改元素在右邊區間
    else updateOne(index, val, y, mid + 1, r);
    // 合併兩個子區間的資料
    pushUp(o, x, y);
}

單點查詢

與單點修改相同,只是不需要進行子區間資料合併了(因為沒有變)

/**
 * @param index 待查詢元素位置
 * @param o     當前節點編號
 * @param l     當前節點管轄區間的左邊界
 * @param r     當前節點管轄區間的右邊界
 * @return index位置處的值
 */
public int queryOne(final int index, int o, int l, int r) {
    if (l == r) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 如果查詢元素在左邊區間
    if (index <= mid) return queryOne(index, x, l, mid);
    // 否則在右邊區間
    return queryOne(index, y, mid + 1, r);
}

區間查詢

求某一個區間的值,對於線段樹就是分解線段樹區間,直至該區間在查詢區間內部,此時該區間的值已經獲得,不需要再分解了

區間查詢的分解區間步驟如下:

  • 如果左子樹包含查詢區間,即 \(queryLeft\le mid\),則查詢左子樹
  • 如果右子樹包含查詢區間,即 \(queryRight>mid\),則查詢右子樹

所需函數引數如下:

  1. 待查詢區間左邊界 \(left\)
  2. 待查詢區間左邊界 \(right\)
  3. 當前節點的編號,即 \(tree\) 陣列中的索引 \(o\)
  4. 該節點所管轄區間的左邊界 \(l\)
  5. 該節點所管轄區間的右邊界 \(r\)
/**
 * @param left  待查詢區間左邊界
 * @param right 待查詢區間右邊界
 * @param o     當前節點編號
 * @param l     當前節點管轄區間的左邊界
 * @param r     當前節點管轄區間的右邊界
 * @return 區間[left, right]的值
 */
public int queryRange(final int left, final int right, int o, int l, int r) {
    // 如果線段樹區間在查詢區間內部,這一區間已經為答案了,不需要再分解了
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    int ans = 0;
    // 如果左子樹包含查詢區間
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    // 如果右子樹包含查詢區間
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

複雜度分析

空間複雜度為 \(O(4n)=O(n)\)

單點修改、單點查詢、區間查詢操作的時間複雜度均為 \(O(\log n)\)

建樹的時間複雜度為樹的節點個數 \(O(2\times n- 1)=O(n)\)

Code

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), k = read();
                seg.updateOne(x, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    int[] tree;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
    }

    // 請保證陣列資料下標從1開始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateOne(final int index, final int val, int o, int l, int r) {
        if (l == r) {
            tree[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) updateOne(index, val, x, l, mid);
        else updateOne(index, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        int ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }

    public int queryOne(final int index, int o, int l, int r) {
        if (l == r) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) return queryOne(index, x, l, mid);
        return queryOne(index, y, mid + 1, r);
    }
}

進階

區間修改+單點查詢

P3368 【模板】樹狀陣列 2 - 洛谷

樹狀陣列相同,可以使用差分的方式,將區間修改變為兩次單點修改,本文對於該方法暫不做討論

區間修改的分解區間步驟與區間查詢類似:

  • 如果左子樹包含修改區間,即 \(queryLeft\le mid\),則修改左子樹
  • 如果右子樹包含修改區間,即 \(queryRight>mid\),則修改右子樹

在上述操作之後,合併兩個子區間的資料

所需函數引數如下:

  1. 待修改區間左邊界 \(left\)
  2. 待修改區間左邊界 \(right\)
  3. 修改後元素(或增量)的資料 \(val\)
  4. 當前節點的編號,即 \(tree\) 陣列中的索引 \(o\)
  5. 該節點所管轄區間的左邊界 \(l\)
  6. 該節點所管轄區間的右邊界 \(r\)
/**
 * @param left  待修改區間左邊界
 * @param right 待修改區間右邊界
 * @param val   修改後元素(或增量)的資料
 * @param o     當前節點編號
 * @param l     當前節點管轄區間的左邊界
 * @param r     當前節點管轄區間的右邊界
 */
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    // 到達葉子節點(管轄區間長度為1)
    if (l == r) {
        tree[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 如果查詢區間全在左邊
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    // 如果查詢區間全在右邊
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 合併兩個子區間的資料
    pushUp(o, x, y);
}

可以發現區間修改時的時間複雜度很高,因為需要對 \([left,right]\) 區間內的每一個葉子都修改,時間複雜度與修改路徑上的節點個數有關,最壞時間複雜度為 \(O(2\times n-1)=O(n)\)

懶惰標記

為了降低區間修改的時間複雜度,讓區間修改的形式與區間查詢的形式相同(即直接修改區間,不修改單個的值),每個節點上多攜帶懶惰標記這個資訊

原理:不用的話我就不修改,只在用的時候(查詢)修改

標記:本區間已經被更新過了,但是子區間卻沒有被更新過,被更新的資訊是什麼。

當路過這個節點時,加上這個標記的值

public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        // 攜帶上val這個資訊,表明該子樹均未修改
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    pushUp(o, x, y);
}
public int queryOne(final int index, int o, int l, int r) {
    if (l == r) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 加上路徑上的懶惰標記值
    if (index <= mid) return tag[o] + queryOne(index, x, l, mid);
    return tag[o] + queryOne(index, y, mid + 1, r);
}

Code

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read();
                out.println(seg.queryOne(x, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    int[] tree, tag;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
        tag = new int[n << 2];
    }

    // 請保證陣列資料下標從1開始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryOne(final int index, int o, int l, int r) {
        if (l == r) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) return tag[o] + queryOne(index, x, l, mid);
        return tag[o] + queryOne(index, y, mid + 1, r);
    }
}

區間修改+區間查詢

P3372 【模板】線段樹 1 - 洛谷

當進行區間修改及區間查詢或多種複雜操作時,可能會覺得直接套用例題中的區間查詢和進階中的區間修改就行。

但事實上不是這樣的,一個錯誤程式碼如下:

private void pushUp(int o, int x, int y) {
    tree[o] = tree[x] + tree[y];
}
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    pushUp(o, x, y);
}
public int queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 算上 該區間 的tag標記
    int ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

\(\{0,0,0,0\}\) 的陣列舉例,初始線段樹圖示如下:

\(1\sim4\) 區間 \(+1\),得:

\(1\sim2\) 區間 \(+1\) 得:

此時發現不對了,\(1\) 號節點的資料被修改了,結果不正確

原因如下:

  • 第一次修改時,對 \(1\sim4\) 的區間 \(+1\) 並沒有傳到子節點,子節點的值沒有發生改變

  • 第二次修改時,對 \(1\sim2\) 的區間 \(+1\)後,呼叫 \(pushUp\) 資料上傳,\(1\) 號節點的資料就不正確了

下面有兩種方式解決該問題

標記永久化

標記永久化:在修改時修改路徑上被影響的節點,在詢問時累加路徑上的標記

區間修改:將路徑上的影響計算到線段樹的 \(data\)

區間查詢:累加查詢路徑上的 \(tag\)(有效區間內的)

public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    // 將後續修改的影響計算到當前節點中
    tree[o] += val * (Math.min(r, right) - Math.max(l, left) + 1);
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 不用pushUp操作
}
public long queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 累加路徑上的tag標記(有效區間內的)
    long ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

Code

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    long[] tree, tag;
    int n;

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        tree[o] = tree[x] + tree[y];
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 2];
        tag = new long[n << 2];
    }

    // 請保證陣列資料下標從1開始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        // 將後續修改的影響計算到當前節點中
        tree[o] += val * (Math.min(r, right) - Math.max(l, left) + 1);
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        // 不用pushUp操作
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 累加路徑上的tag標記(有效區間內的)
        long ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }
}

標記下傳

標記下傳:把一個節點的懶惰標記傳給它的左右兒子,再把該節點的懶惰標記刪去

當執行到某一節點時,先下傳當前節點的標記,再查詢或更新,最後 \(pushUp\) 的就是正確結果

// 標記下傳(若下方標記與區間邊界無關,則不需要l,r引數)
private void pushDown(int o, int x, int y, int l, int r) {
    // 空標記直接退出
    if (tag[o] == 0) return;
    int mid = l + r >> 1;
    // 下傳給左節點
    tag[x] += tag[o];
    tree[x] += tag[o] * (mid - l + 1);
    // 下傳給右節點
    tag[y] += tag[o];
    tree[y] += tag[o] * (r - (mid + 1) + 1);
    // 清空當前節點標記
    tag[o] = 0;
}
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 將後續修改的影響計算到當前節點中
    // 下放標記
    pushDown(o, x, y, l, r);
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 上傳子區間資料
    pushUp(o, x, y);
}
public long queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 將後續修改的影響計算到當前節點中
    // 下放標記
    pushDown(o, x, y, l, r);
    long ans = 0;
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    // 上傳子區間資料
    pushUp(o, x, y);
    return ans;
}

Code

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    long[] tree, tag;
    int n;

    // 標記下傳
    private void pushDown(int o, int x, int y, int l, int r) {
        // 空標記直接退出
        if (tag[o] == 0) return;
        int mid = l + r >> 1;
        // 下傳給左節點
        tag[x] += tag[o];
        tree[x] += tag[o] * (mid - l + 1);
        // 下傳給右節點
        tag[y] += tag[o];
        tree[y] += tag[o] * (r - (mid + 1) + 1);
        // 清空當前節點標記
        tag[o] = 0;
    }

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 2];
        tag = new long[n << 2];
    }

    // 請保證陣列資料下標從1開始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        // 將後續修改的影響計算到當前節點中
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 下放標記
        pushDown(o, x, y, l, r);
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        // 上傳子區間資料
        pushUp(o, x, y);
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 下放標記
        pushDown(o, x, y, l, r);
        long ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        // 上傳子區間資料
        pushUp(o, x, y);
        return ans;
    }
}

空間優化

下面給出P3372 【模板】線段樹 1 - 洛谷空間優化的線段樹類程式碼

\(2n\) 空間

對於上述的線段樹,要用到 \(4n\) 的空間,但只有 \(2n-1\) 個空間有作用,能不能只建立 \(2n\) 個空間?

深度優先搜尋 \(DFS\) 是樹的一種遍歷方式,而 \(DFS\) 序是深度優先搜尋中的節點存取次序,記為 \(DFN\),選擇按照 \(DFN\) 的方式儲存線段樹節點

若某一個節點的編號為 \(p\),則其左兒子節點編號為 \(p+1\),則其右兒子節點編號為 \(p+左子樹節點個數+1\)(因為是先遍歷左子樹嘛)

那左子樹節點個數該怎麼求呢?

線上段樹的儲存中提到

一顆管理陣列長度為 \(n\) 的線段樹的節點個數為 \(2\times n -1\)

若當前節點管理區間為 \([l,r]\),設 \(mid=\lfloor\dfrac{l+r}{2}\rfloor\),則左子樹管理區間為 \([l,mid]\),左子樹管理區間長度為 \(mid-l+1\),所以左子樹節點個數為 \(2\times(\lfloor\dfrac{l+r}{2}\rfloor-l+1)-1=2\times\lfloor\dfrac{r-l+2}{2}\rfloor-1\)

因此,右兒子節點編號為 \(p+2\times\lfloor\dfrac{r-l+2}{2}\rfloor\)

\(2\) 向下取整代表按位元右移,乘 \(2\) 代表按位元左移。

因此,\(2\times\lfloor\dfrac{r-l+2}{2}\rfloor\) 可用 \(r-l+2\) 並將二進位制最低位置為 \(0\) 表示

(r - l + 2) & ~1

點選檢視程式碼
class SegmentTree {
    long[] tree, tag;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 1];
        tag = new long[n << 1];
    }

    // 請保證陣列資料下標從1開始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    private void pushDown(int o, int x, int y, int l, int r) {
        if (tag[o] == 0) return;
        int mid = l + r >> 1;
        tag[x] += tag[o];
        tree[x] += tag[o] * (mid - l + 1);
        tag[y] += tag[o];
        tree[y] += tag[o] * (r - mid);
        tag[o] = 0;
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        long ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        pushUp(o, x, y);
        return ans;
    }
}

動態開點

所有節點只在使用時才\(new\)申請記憶體,\(Cpp\) 通過指標、\(Java\) 通過參照的方式動態開點

如果動態開點進行 \(build\) 建樹操作,就會將所有節點建立出來,就和上面 \(2n\) 空間一樣了

因此,動態開點一般是不建樹的

數的加法是不影響初值的,我們將初值取出,只對全為 \(0\) 的線段樹進行區間加法和查詢

在總的查詢的時候加上原陣列的值即可

最壞的情況就是所有操作都走不同的路徑、且走到葉子節點,則一次操作會增加 \(\log n\) 個節點

\(n\) 為查詢區間長度,\(m\) 為詢問操作次數,則空間複雜度為:\(O(min(2n-1,\ m\log n))\)

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = a[i - 1] + read();
        SegmentTree seg = new SegmentTree(n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.modify(x, y, k);
            } else {
                int x = read(), y = read();
                long ans = a[y] - a[x - 1];
                ans += seg.query(x, y).sum;
                out.println(ans);
            }
        }
        out.close();
    }
}

class SegmentTree {
    class node {
        // 設定節點預設空白初始值, 用於答案查詢及建立節點
        long sum = 0, add = 0;
        node lChild, rChild;

        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void apply(int l, int r, final long val) {
            sum += (r - l + 1) * val;
            add += val;
        }

        // 建立兒子節點
        public void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    // 標記下傳, 將cur節點的標記下傳至兩個子樹中
    // 按照需要選擇是否需要左右邊界
    void pushDown(node cur, int l, int r) {
        if (cur.add != 0) {
            int mid = l + r >> 1;
            cur.lChild.apply(l, mid, cur.add);
            cur.rChild.apply(mid + 1, r, cur.add);
            cur.add = 0;
        }
    }

    // son的data資料加到cur上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    void unite(node cur, final node son) {
        cur.sum += son.sum;
    }

    // 子區間資料上傳
    void pushUp(node cur) {
        cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node root;

    // 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 區間修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }
    
    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

要點總結

如果覺得我講的不是很明白,可以看參考資料中提到的文章或者 \(Bilibili\)

  • 樹狀陣列不同,線段樹只需要兩個區間資訊可以合併,即維護的資訊只需要滿足結合律(如加法、乘法、互斥或、最大公約數等)
    結合律:\((x\circ y)\circ z=x\circ(y\circ z)\),其中 \(\circ\) 是一個二元運運算元。

  • 帶懶惰標記的線段樹修改和查詢操作時間複雜度均為 \(O(\log n)\),建樹時間複雜度為 \(O(n)\)

  • 對標記的兩種操作各有各的優點

    標記下傳的實用性更廣

    標記永久化思想還可用於可持久化資料結構

  • 線段樹對於非強制線上的問題可以通過離散化縮小資料範圍來減少空間

    而對於強制線上的問題就只能通過動態開點來減少空間了(應該

    當然,離散化也可以和動態開點搭配

    強制線上:不提前給出所有涉及詢問和修改的區間範圍,不能進行離散化

  • 線段樹是一種工具,許多問題可以藉助這個工具解決,就如同滑動視窗可以藉助雙端佇列解決一樣

線段樹封裝類

陣列有效資料下標均從 \(1\) 開始

基礎的不帶懶標記的線段樹暫時不提供了,因為對於這類問題,樹狀陣列大多可以解決

如果您有更好的封裝類能提供給我,我將感激不盡

懶標記線段樹

C++

點選檢視程式碼
template <typename T>
class SegmentTree {
public:
    struct node {
        // 設定葉子節點預設初始值, 用於不傳陣列的建樹以及空標記
        T data = ...;
        T tag = ...;
        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void apply(..., const T &val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
        // 建樹時傳入陣列的初始化
        void init(const T &val) {
            ...
            // sum = val;
        }
    };
    // 標記下傳, 將o節點的標記下傳至兩個子樹x,y中
    // 按照需要選擇是否需要左右邊界
    void pushDown(int o, int x, int y) {
        ...
        // if (tree[o].add != 0) {
        //     int mid = l + r >> 1;
        //     // 下傳標記至左子樹
        //     tree[x].apply(l, mid, tree[o].add);
        //     // 下傳標記至右子樹
        //     tree[y].apply(mid + 1, r, tree[o].add);
        //     // 清空當前節點標記
        //     tree[o].add = 0;
        // }
    }
    // son的data資料加到o上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    void unite(node& o, const node& son) {
        ...
        // o.sum += son.sum;
    }
    // 子區間資料上傳
    void pushUp(int o, int x, int y) {
        // 清空當前節點data
        ...
        // tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }
    void build(int o, int l, int r) {
        if (l == r) return;
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }
    void build(int o, int l, int r, const std::vector<T> &val) {
        if (l == r) {
            tree[o].init(val[l]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }
    void modify(int o, int l, int r, const int &left, const int &right, const T &val) {
        if (left <= l && r <= right) {
            tree[o].apply(..., val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void query(int o, int l, int r, const int& left, const int& right, node& res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }
    int n;
    std::vector<node> tree;
    // 不傳入陣列的預設建樹, 請保證node處進行了預設初始化
    SegmentTree(int _n) : n(_n) {
        assert(n > 0);
        tree.resize(n << 1);
        build(1, 1, n);
    }
    // 傳入陣列的建樹, 請保證陣列有效資料下標從1開始
    SegmentTree(std::vector<T>& val, int _n) : n(_n) {
        assert((int)val.size() >= _n);
        tree.resize(n << 1);
        build(1, 1, n, val);
    }
    // 單點修改
    void modify(const int& index, const T& val) {
        assert(1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }
    // 區間修改
    void modify(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }
    // 單點查詢
    node query(const int& index) {
        assert(1 <= index && index <= n);
        node res{};
        query(1, 1, n, index, index, res);
        return res;
    }
    // 區間查詢
    node query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node res{};
        query(1, 1, n, left, right, res);
        return res;
    }
};

Java

由於本人對於 \(Java\) 泛型還不熟悉,以後學明白了再修改(挖坑,希望會填

點選檢視程式碼
class SegmentTree {
    class node {
        int data, tag;

        public node(int _data, int _tag) {
            data = _data;
            tag = _tag;
        }

        // 建樹時傳入陣列的初始化(標記置空)
        public node(int _data) {this(_data, 0);}

        // 設定葉子節點預設初始值, 用於不傳陣列的建樹、置空標記以及查詢答案的初始化
        public node() {this(0, 0);}

        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void apply(...,final int val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
    }

    // 標記下傳, 將o節點的標記下傳至兩個子樹x,y中
    // 按照需要選擇是否需要左右邊界
    private void pushDown(int o, int x, int y) {
        ...
        // if (tree[o].add != 0) {
        //     int mid = l + r >> 1;
        //     // 下傳標記至左子樹
        //     tree[x].apply(l, mid, tree[o].add);
        //     // 下傳標記至右子樹
        //     tree[y].apply(mid + 1, r, tree[o].add);
        //     // 清空當前節點標記
        //     tree[o].add = 0;
        // }
    }

    // son的data資料加到o上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    private void unite(node o, final node son) {
        ...
        // o.sum += son.sum;
    }

    // 子區間資料上傳
    private void pushUp(int o, int x, int y) {
        // 清空當前節點data
        ...
        // tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            tree[o].apply(...,val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    private int n;
    private node[] tree;
    
    // 不傳入陣列的預設建樹, 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 傳入陣列的建樹, 請保證陣列有效資料下標從1開始
    public SegmentTree(final int[] val, int _n) {
        // assert ((int) val.length >= _n);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n, val);
    }

    // 單點修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }

    // 區間修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }

    // 單點查詢
    public node query(int index) {
        // assert (1 <= index && index <= n);
        node res = new node();
        query(1, 1, n, index, index, res);
        return res;
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

動態開點的懶標記線段樹

Java

點選檢視程式碼
class SegmentTree {
    class node {
        // 設定節點預設空白初始值, 用於答案查詢及建立節點
        int data = 0, tag = 0;
        node lChild, rChild;

        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        private void apply(...,final int val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }

        // 建立子節點
        private void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    // 標記下傳, 將cur節點的標記下傳至兩個子樹中
    // 按照需要選擇是否需要左右邊界
    private void pushDown(node cur, ...) {
        ...
        // if (cur.add != 0) {
        //     int mid = l + r >> 1;
        //     // 下傳標記至左子樹
        //     cur.lChild.apply(l, mid, cur.add);
        //     // 下傳標記至右子樹
        //     cur.rChild.apply(mid + 1, r, cur.add);
        //     // 清空當前節點標記
        //     cur.add = 0;
        // }
    }

    // son的data資料加到cur上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    private void unite(node cur, final node son) {
        ...
        // cur.sum += son.sum;
    }

    // 子區間資料上傳
    private void pushUp(node cur) {
        // 清空當前節點data
        ...
        // cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(...,val);
            // cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    private int n;
    private node root;

    // 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 單點修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }

    // 區間修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }

    // 單點查詢
    public node query(int index) {
        // assert (1 <= index && index <= n);
        node res = new node();
        query(root, 1, n, index, index, res);
        return res;
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

C++

箭頭操作符\(=\)解除參照\(+\)點操作符

p->data(*p).data相同

點選檢視程式碼
template <typename T>
class SegmentTree {
private:
    struct node {
        // 設定節點預設空白初始值, 用於答案查詢及建立節點
        T data = ..., tag = ...;
        node* lChild = nullptr;
        node* rChild = nullptr;
        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void apply(..., const T& val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
        // 建立兒子節點
        void addNode() {
            if (!lChild) lChild = new node();
            if (!rChild) rChild = new node();
        }
    };
    // 標記下傳, 將cur節點的標記下傳至兩個子樹中
    // 按照需要選擇是否需要左右邊界
    void pushDown(node* cur, ...) {
        ...
        // if (cur->add != 0) {
        //     int mid = l + r >> 1;
        //     cur->lChild->apply(l, mid, cur->add);
        //     cur->rChild->apply(mid + 1, r, cur->add);
        //     cur->add = 0;
        // }
    }

    // son的data資料加到cur上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    void unite(node* cur, const node* son) {
        ...
        // cur->sum += son->sum;
    }

    // 子區間資料上傳
    void pushUp(node* cur) {
        // 清空當前節點data
        ...
        // cur->sum = 0;
        unite(cur, cur->lChild);
        unite(cur, cur->rChild);
    }
    void modify(node* cur, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            cur->apply(..., val);
            // cur->apply(l, r, val);
            return;
        }
        cur->addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur->lChild, l, mid, left, right, val);
        if (right > mid) modify(cur->rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    void query(node* cur, int l, int r, const int& left, const int& right, node* res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur->addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur->lChild, l, mid, left, right, res);
        if (right > mid) query(cur->rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node* root;

public:
    // 請保證node處進行了預設初始化
    SegmentTree(const int& _n) : n(_n), root(new node()) {
        assert(n > 0);
    }
    // 單點修改
    void modify(const int& index, const T& val) {
        assert(1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }
    // 區間修改
    void modify(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }
    // 單點查詢
    node* query(const int& index) {
        assert(1 <= index && index <= n);
        node* res = new node();
        query(root, 1, n, index, index, res);
        return res;
    }
    // 區間查詢
    node* query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node* res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
};

題目

P3373 線段樹2 - 洛谷

題目連結

題意簡述:有三個操作

  1. 對區間 \([l,r]\) 每個數乘上 \(k\)
  2. 對區間 \([l,r]\) 每個數加上 \(k\)
  3. 查詢區間 \([l,r]\) 每個數的和

兩種修改操作對應兩種懶惰標記, 優先下傳乘法標記,將乘法標記對加法標記的影響算在加法標記中

Java Code

點選檢視程式碼
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read(), mod = read();
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read(), x = read(), y = read();
            if (command == 1) seg.modifyMul(x, y, read());
            else if (command == 2) seg.modifyAdd(x, y, read());
            else out.println(seg.query(x, y).sum);
        }
        out.close();
    }
}

class SegmentTree {
    static final int mod = 571373;

    class node {
        long sum, add, mul;

        public node(long _sum, long _add, long _mul) {
            sum = _sum % mod;
            add = _add % mod;
            mul = _mul % mod;
        }

        public node(long _sum) {this(_sum, 0, 1);}

        public node() {this(0, 0, 1);}

        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void applyAdd(int l, int r, final long val) {
            sum = (sum + (r - l + 1) * val) % mod;
            add = (add + val) % mod;
        }

        void applyMul(final long val) {
            sum = sum * val % mod;
            mul = mul * val % mod;
            add = add * val % mod;
        }
    }

    void pushDown(int o, int x, int y, int l, int r) {
        if (tree[o].mul != 1) {
            tree[x].applyMul(tree[o].mul);
            tree[y].applyMul(tree[o].mul);
            tree[o].mul = 1;
        }
        if (tree[o].add != 0) {
            int mid = l + r >> 1;
            tree[x].applyAdd(l, mid, tree[o].add);
            tree[y].applyAdd(mid + 1, r, tree[o].add);
            tree[o].add = 0;
        }
    }

    void unite(node o, final node son) {
        o.sum = (o.sum + son.sum) % mod;
    }

    void pushUp(int o, int x, int y) {
        tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r, final long[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modifyAdd(int o, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            tree[o].applyAdd(l, r, val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyAdd(x, l, mid, left, right, val);
        if (right > mid) modifyAdd(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void modifyMul(int o, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            tree[o].applyMul(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyMul(x, l, mid, left, right, val);
        if (right > mid) modifyMul(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    int n;
    node[] tree;

    public SegmentTree(final long[] val, int _n) {
        // assert ((int) val.length >= _n);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n, val);
    }

    // 區間修改1
    void modifyMul(final int left, final int right, final long val) {
        assert (1 <= left && left <= right && right <= n);
        modifyMul(1, 1, n, left, right, val);
    }

    // 區間修改2
    void modifyAdd(int left, int right, final long val) {
        // assert (1 <= left && left <= right && right <= n);
        modifyAdd(1, 1, n, left, right, val);
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

C++20 Code

用到了自動取模類,詳見逆元一文

點選檢視程式碼
#include <bits/stdc++.h>
using namespace std;

template <typename T>
class SegmentTree {
public:
    struct node {
        // 設定葉子節點預設初始值, 用於不傳陣列的建樹以及空標記
        T sum = 0;
        T add = 0;
        T mul = 1;
        // val 加到data和tag上, 用於區間修改終止和標記下傳
        // 按照需要選擇是否需要左右邊界
        void applyAdd(int l, int r, const T& val) {
            sum += T(r - l + 1) * val;
            add += val;
        }
        void applyMultiply(const T& val) {
            sum *= val;
            mul *= val;
            add *= val;
        }
        // 建樹時傳入陣列的初始化
        void init(const T& val) {
            sum = val;
        }
    };
    // 標記下傳, 將o節點的標記下傳至兩個子樹x,y中
    // 按照需要選擇是否需要左右邊界
    void pushDown(int o, int x, int y, int l, int r) {
        if (tree[o].mul != 1) {
            tree[x].applyMultiply(tree[o].mul);
            tree[y].applyMultiply(tree[o].mul);
            tree[o].mul = 1;
        }
        if (tree[o].add != 0) {
            int mid = l + r >> 1;
            tree[x].applyAdd(l, mid, tree[o].add);
            tree[y].applyAdd(mid + 1, r, tree[o].add);
            tree[o].add = 0;
        }
    }
    // son的data資料加到o上, 用於pushUp上傳資料 和 查詢時合併答案
    // 不用理會標記數值(前提是node有預設初始值,且代表空標記)
    void unite(node& o, const node& son) {
        o.sum += son.sum;
    }
    // 子區間資料上傳
    void pushUp(int o, int x, int y) {
        tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }
    void build(int o, int l, int r, const std::vector<T>& val) {
        if (l == r) {
            tree[o].init(val[l]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }
    void modifyAdd(int o, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            tree[o].applyAdd(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyAdd(x, l, mid, left, right, val);
        if (right > mid) modifyAdd(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void modifyMultiply(int o, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            tree[o].applyMultiply(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyMultiply(x, l, mid, left, right, val);
        if (right > mid) modifyMultiply(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void query(int o, int l, int r, const int& left, const int& right, node& res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }
    int n;
    std::vector<node> tree;
    // 傳入陣列的建樹, 請保證陣列有效資料下標從1開始
    SegmentTree(std::vector<T>& val, int _n) : n(_n) {
        assert((int)val.size() >= _n);
        tree.resize(n << 1);
        build(1, 1, n, val);
    }
    // 區間修改1
    void modifyMultiply(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modifyMultiply(1, 1, n, left, right, val);
    }
    // 區間修改2
    void modifyAdd(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modifyAdd(1, 1, n, left, right, val);
    }
    // 區間查詢
    node query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node res{};
        query(1, 1, n, left, right, res);
        return res;
    }
};

template <int MOD>
struct modint {
    int val;
    static int norm(const int& x) { return x < 0 ? x + MOD : x; }
    static constexpr int get_mod() { return MOD; }
    modint inv() const {
        assert(val);
        int a = val, b = MOD, u = 1, v = 0, t;
        while (b > 0) t = a / b, swap(a -= t * b, b), swap(u -= t * v, v);
        assert(b == 1);
        return modint(u);
    }
    modint() : val(0) {}
    modint(const int& m) : val(norm(m)) {}
    modint(const long long& m) : val(norm(m % MOD)) {}
    modint operator-() const { return modint(norm(-val)); }
    bool operator==(const modint& o) { return val == o.val; }
    bool operator<(const modint& o) { return val < o.val; }
    modint& operator+=(const modint& o) { return val = (1ll * val + o.val) % MOD, *this; }
    modint& operator-=(const modint& o) { return val = norm(1ll * val - o.val), *this; }
    modint& operator*=(const modint& o) { return val = static_cast<int>(1ll * val * o.val % MOD), *this; }
    modint& operator/=(const modint& o) { return *this *= o.inv(); }
    modint& operator^=(const modint& o) { return val ^= o.val, *this; }
    modint& operator>>=(const modint& o) { return val >>= o.val, *this; }
    modint& operator<<=(const modint& o) { return val <<= o.val, *this; }
    modint operator-(const modint& o) const { return modint(*this) -= o; }
    modint operator+(const modint& o) const { return modint(*this) += o; }
    modint operator*(const modint& o) const { return modint(*this) *= o; }
    modint operator/(const modint& o) const { return modint(*this) /= o; }
    modint operator^(const modint& o) const { return modint(*this) ^= o; }
    modint operator>>(const modint& o) const { return modint(*this) >>= o; }
    modint operator<<(const modint& o) const { return modint(*this) <<= o; }
    friend std::istream& operator>>(std::istream& is, modint& a) {
        long long v;
        return is >> v, a.val = norm(v % MOD), is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint& a) { return os << a.val; }
    friend std::string tostring(const modint& a) { return std::to_string(a.val); }
    friend modint qpow(const modint& a, const int& b) {
        assert(b >= 0);
        modint x = a, res = 1;
        for (int p = b; p; x *= x, p >>= 1)
            if (p & 1) res *= x;
        return res;
    }
};

constexpr int mod = 571373;
using Mint = modint<mod>;

signed main() {
    std::ios_base::sync_with_stdio(false), std::cin.tie(nullptr), std::cout.tie(nullptr);
    int n, m, command, x, y;
    cin >> n >> m >> x;
    vector<Mint> a(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    SegmentTree<Mint> seg(a, n);
    Mint k;
    while (m--) {
        cin >> command >> x >> y;
        if (command == 1) {
            cin >> k;
            seg.modifyMultiply(x, y, k);
        } else if (command == 2) {
            cin >> k;
            seg.modifyAdd(x, y, k);
        } else {
            cout << seg.query(x, y).sum << endl;
        }
    }
    return 0;
}

315. 計算右側小於當前元素的個數 - 力扣

題目連結

題意簡述:求逆序對

求逆序對可以通過歸併排序,也可以通過樹狀陣列\(+\)離散化

樹狀陣列可以做這道題,那線段樹也一定可以

對於線段樹可以選擇動態開點,也可以選擇離散化

這題資料範圍很小,正常做好像也能過

Java 離散化

點選檢視程式碼
class Solution {
    int n, max;
    // 去重
    int adjacentRemove(int[] nums) {
        int slow = 0;
        for (int fast = 1; fast < n; ++fast) {
            if (nums[slow] != nums[fast]) {
                nums[++slow] = nums[fast];
            }
        }
        return slow + 1;
    }
    //離散化
    void lis(int[] a) {
        int[] temp = new int[n];
        System.arraycopy(a, 0, temp, 0, n);
        Arrays.sort(temp);
        max = adjacentRemove(temp);
        for (int i = 0; i < n; ++i) {
            // 對映到 [1, max-1] 的區間上
            a[i] = Arrays.binarySearch(temp,0, max, a[i]) + 1;
        }
    }
    public List<Integer> countSmaller(int[] nums) {
        n = nums.length;
        lis(nums);
        SegmentTree seg = new SegmentTree(max);
        List<Integer> ans = new ArrayList<Integer>(n);
        // 找右側 有多少個 比 當前數 小 的數
        for (int i = n - 1; i >= 0; --i) {
            seg.modify(nums[i], 1);
            if (nums[i] - 1 == 0) ans.add(0);
            else ans.add(seg.query(1, nums[i] - 1).sum);
        }
        Collections.reverse(ans);
        return ans;
    }
}

class SegmentTree {
    class node {
        int sum, add;

        public node(int _sum, int _add) {
            sum = _sum;
            add = _add;
        }

        public node(int _sum) {this(_sum, 0);}

        public node() {this(0, 0);}

        void apply(int l,int r,final int val) {
             sum += (r - l + 1) * val;
             add += val;
        }
    }

    void pushDown(int o, int x, int y,int l,int r) {
         if (tree[o].add != 0) {
             int mid = l + r >> 1;
             tree[x].apply(l, mid, tree[o].add);
             tree[y].apply(mid + 1, r, tree[o].add);
             tree[o].add = 0;
         }
    }

    void unite(node o, final node son) {
         o.sum += son.sum;
    }

    void pushUp(int o, int x, int y) {
         tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
             tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
         pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
         pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    int n;
    node[] tree;

    // 不傳入陣列的預設建樹, 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 單點修改
    void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

Java 動態開點

點選檢視程式碼
class Solution {
    public List<Integer> countSmaller(int[] nums) {
        final int max = (int) 1e4 + 1;
        int n = nums.length;
        // 有負數, 整體都加上 max, 保證數都大於 0
        SegmentTree seg = new SegmentTree(2 * max);
        List<Integer> ans = new ArrayList<Integer>(n);
        // 找右側 有多少個 比 當前數 小 的數
        for (int i = n - 1; i >= 0; --i) {
            int val = nums[i] + max;
            seg.modify(val, 1);
            if (val - 1 == 0) ans.add(0);
            else ans.add(seg.query(1, val - 1).sum);
        }
        Collections.reverse(ans);
        return ans;
    }
}

class SegmentTree {
    class node {
        int sum = 0, add = 0;
        node lChild, rChild;

        void apply(int l, int r, final int val) {
            sum += (r - l + 1) * val;
            add += val;
        }

        public void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    void pushDown(node cur, int l, int r) {
        if (cur.add != 0) {
            int mid = l + r >> 1;
            cur.lChild.apply(l, mid, cur.add);
            cur.rChild.apply(mid + 1, r, cur.add);
            cur.add = 0;
        }
    }

    void unite(node cur, final node son) {
        cur.sum += son.sum;
    }

    void pushUp(node cur) {
        cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node root;

    // 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 單點修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

307. 區域和檢索 - 陣列可修改 - 力扣

題目連結

題意簡述:有兩個操作

  1. 單點賦值
  2. 區間和查詢

單點賦值可以改為單點查詢+單點修改(查詢這個值再減去這個值)

當然也可以多開一個 \(nums\) 陣列維護單點值

單點修改加區間查詢,懶標記都不需要

再看資料範圍,\(1 <= nums.length <= 3 \times 10^4\),離散化、動態開點也不需要

點選檢視程式碼
class NumArray {
    int n;
    SegmentTree seg;
    int[] nums;
    public NumArray(int[] _nums) {
        n = _nums.length;
        seg = new SegmentTree(_nums, n);
        nums = _nums;
    }
    
    public void update(int index, int val) {
        seg.updateOne(index + 1, val - nums[index], 1, 1, n);
        nums[index] = val;
    }
    
    public int sumRange(int left, int right) {
        return seg.queryRange(left + 1, right + 1, 1, 1, n);
    }
}

class SegmentTree {
    int[] tree;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l - 1];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
    }

    public SegmentTree(int[] val, int _n) {
        this(_n);
        build(1, 1, n, val);
    }

    public void updateOne(final int index, final int val, int o, int l, int r) {
        if (l == r) {
            tree[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) updateOne(index, val, x, l, mid);
        else updateOne(index, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        int ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }
}

699. 掉落的方塊 - 力扣

題目連結

題意簡述:

俄羅斯方塊,從上向下降落正方形,有交集則墊高,擦邊的不算有交集,求最高的高度

每個右邊界都\(-1\),解決擦邊的問題

先查詢 \([left,right]\) 區間內的最大值,再將 \([left,right]\) 賦值為該最大值+方塊尺寸

題目變為 區間賦值 和 區間最大值查詢(注意是區間賦值)

Java 動態開點

點選檢視程式碼
class Solution {
    public List<Integer> fallingSquares(int[][] positions) {
        List<Integer> ans = new ArrayList<Integer>(positions.length);
        final int max = ((int) 1e8) + ((int) 1e6);
        SegmentTree seg = new SegmentTree(max);
        for (int[] v : positions) {
            int left = v[0], len = v[1], right = left + len - 1;
            int currentMax = seg.query(left, right).max;
            seg.modify(left, right, currentMax + len);
            ans.add(seg.query(1, max).max);
        }
        return ans;
    }
}

class SegmentTree {
    class node {
        int max = 0, assign = 0;
        node lChild, rChild;

        private void apply(final int val) {
            max = val;
            assign = val;
        }

        private void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    private void pushDown(node cur) {
        if (cur.assign != 0) {
            cur.lChild.apply(cur.assign);
            cur.rChild.apply(cur.assign);
            cur.assign = 0;
        }
    }

    private void unite(node cur, final node son) {
        cur.max = Math.max(cur.max, son.max);
    }

    private void pushUp(node cur) {
        cur.max = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(val);
            return;
        }
        cur.addNode();
        pushDown(cur);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    private int n;
    private node root;

    // 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 區間修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

Java 離散化

點選檢視程式碼
class Solution {
    int[] rank;
    int cnt;

    int adjacentRemove(int[] a, int n) {
        int slow = 0;
        for (int fast = slow + 1; fast < n; ++fast) {
            if (a[slow] != a[fast] && ++slow != fast) {
                a[slow] = a[fast];
            }
        }
        return slow + 1;
    }

    void discrete(int[][] positions, int n) {
        rank = new int[n << 1];
        for (int i = 0; i < n; ++i) {
            rank[i << 1] = positions[i][0];
            rank[i << 1 | 1] = positions[i][0] + positions[i][1] - 1;
        }
        Arrays.sort(rank);
        cnt = adjacentRemove(rank, n << 1);
    }

    // 查詢對映(大於0)
    int find(int val) {
        return Arrays.binarySearch(rank, 0, cnt, val) + 1;
    }

    public List<Integer> fallingSquares(int[][] positions) {
        int n = positions.length;
        List<Integer> ans = new ArrayList<Integer>(n);
        discrete(positions, n);
        SegmentTree seg = new SegmentTree(cnt);
        for (int[] v : positions) {
            int left = find(v[0]), len = v[1], right = find(v[0] + len - 1);
            int currentMax = seg.query(left, right).max;
            seg.modify(left, right, currentMax + len);
            ans.add(seg.query(1, cnt).max);
        }
        return ans;
    }
}

class SegmentTree {
    class node {
        int max, assign;

        public node(int _max, int _assign) {
            max = _max;
            assign = _assign;
        }

        public node(int _max) {this(_max, 0);}

        public node() {this(0, 0);}

        void apply(final int val) {
            max = val;
            assign = val;
        }
    }

    private void pushDown(int o, int x, int y) {
        if (tree[o].assign != 0) {
            tree[x].apply(tree[o].assign);
            tree[y].apply(tree[o].assign);
            tree[o].assign = 0;
        }
    }

    private void unite(node o, final node son) {
        o.max = Math.max(o.max, son.max);
    }

    private void pushUp(int o, int x, int y) {
        tree[o].assign = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            tree[o].apply(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    private int n;
    private node[] tree;

    // 不傳入陣列的預設建樹, 請保證node處進行了預設初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 區間修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }

    // 區間查詢
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

參考資料

線段樹詳解與實現 - 知乎

線段樹詳解 (原理,實現與應用) - AC_King

線段樹從入門到急停 - yukiyama

一維線段樹的2n空間實現

線段樹節點個數的遞推公式與通項公式 - Hoxily

關於線段樹的陣列到底是開2N還是4N - 知乎