動態開點線段樹說明

2022-12-29 06:00:47

動態開點線段樹說明

作者:Grey

原文地址:

部落格園:動態開點線段樹說明

CSDN:動態開點線段樹說明

說明

針對普通線段樹,參考使用線段樹解決陣列任意區間元素修改問題

在普通線段樹中,線段樹在預處理的時候,需要申請 4 倍大小的陣列空間來存放劃分的區域,

而本文介紹的動態開點線段樹,它和普通線段樹的區別是,動態開點線段樹不需要像普通線段樹那樣提前申請 4 倍大小的資料空間來存放劃分割區域,等到實際使用的時候,再來申請。

先講一種比較簡單的動態開點線段樹,這種線段樹只支援單點的更新和查詢。

即支援如下兩個方法

void add(i, v);

該方法表示在 i 上的值加上 v;

int query(int s, int e)

該方法用於獲取 s 到 e 區間內的累加和資訊。

該線段樹只需要定義一個節點資料結構即可

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

其中 sum 表示 Node 所在區間的累加和,left 表示節點左孩子資訊,right 表示節點右孩子資訊。

線段樹初始化過程也只需要

  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }
  }

size 表示線段樹支援的範圍,這個範圍從線段樹一開始初始化的時候設定好(編號1 到 編號size就是區間範圍)。和普通線段樹不一樣的地方在於,節點只建立了 root 節點,未初始化所有區間。

接下來看add方法,

    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

這個方法呼叫了線段樹內部的私有add方法,

    // c-> cur 當前節點!表達的範圍 l~r
    // i位置的數,增加v
    // 潛臺詞!i一定在l~r範圍上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 還可以劃分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

這個add方法的幾個引數分別代表

c : 表示 add 操作的區間代表節點是多少

l...r 表示任務區間,由於初始化 size,所以在呼叫公開的 add 方法時候,l = 1, r = size,表示在初始化區間範圍內操作。

i:表示要操作的位置

v: 表示要增加的值

整個 add 私有方法邏輯也比較簡單,核心程式碼

        // i 在節點左邊
        if (i <= mid) { 
            // 如果節點的左樹為空,則建立新節點
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { 
            // i 在節點右邊
            // 如果節點右樹為空,則建立新節點
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        // 最後當前節點要匯聚左右樹的結果,之所以要判空是因為左右樹可能不需要都建立出來
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);

查詢方法的邏輯也比較簡單

    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

呼叫了內部的一個私有 query 方法,

    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) { 
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

這個私有方法的幾個引數說明如下

c:表示要操作的線段樹的代表節點是什麼;

l...r 是劃分的區間範圍

s...e 是任務的區間範圍

核心邏輯如下

// 如果任務的區間已經包含了劃分的區間,直接返回結果
      if (s <= l && r <= e) { 
        return c.sum;
      }
      // 否則,去左右區間拿累加和
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        // 整合成自己的累加和返回
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }

整個支援單點更新的動態線段樹的完整程式碼如下(含對數器程式碼)

// 只支援單點增加 + 範圍查詢的動態開點線段樹(累加和)
public class Code01_DynamicSegmentTree {

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

  // arr[0] -> 1
  // 線段樹,從1開始下標!
  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }

    // 下標i這個位置的數,增加v
    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

    // c-> cur 當前節點!表達的範圍 l~r
    // i位置的數,增加v
    // 潛臺詞!i一定在l~r範圍上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 還可以劃分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

    // s~e範圍的累加和
    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

    // 當前節點c,表達的範圍l~r
    // 收到了一個任務,s~e這個任務!
    // s~e這個任務,影響了多少l~r範圍的數,把答案返回!
    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) {
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

  public static class Right {
    public int[] arr;

    public Right(int size) {
      arr = new int[size + 1];
    }

    public void add(int i, int v) {
      arr[i] += v;
    }

    public int query(int s, int e) {
      int sum = 0;
      for (int i = s; i <= e; i++) {
        sum += arr[i];
      }
      return sum;
    }
  }

  public static void main(String[] args) {
    int size = 10000;
    int testTime = 50000;
    int value = 500;
    DynamicSegmentTree dst = new DynamicSegmentTree(size);
    Right right = new Right(size);
    System.out.println("測試開始");
    for (int k = 0; k < testTime; k++) {
      if (Math.random() < 0.5) {
        int i = (int) (Math.random() * size) + 1;
        int v = (int) (Math.random() * value);
        dst.add(i, v);
        right.add(i, v);
      } else {
        int a = (int) (Math.random() * size) + 1;
        int b = (int) (Math.random() * size) + 1;
        int s = Math.min(a, b);
        int e = Math.max(a, b);
        int ans1 = dst.query(s, e);
        int ans2 = right.query(s, e);
        if (ans1 != ans2) {
          System.out.println("出錯了!");
          System.out.println(ans1);
          System.out.println(ans2);
        }
      }
    }
    System.out.println("測試結束");
  }
}

接下來看一個使用動態開點線段樹來解決的一個問題

即:LeetCode 315. Count of Smaller Numbers After Self

注:本題可以用歸併排序,樹狀陣列,有序表來解,也可以用動態開點線段樹來解。

主要思路如下

以如下陣列為例來說明

nums = {5,8,7,4,2,9}

首先,初始化一個 List,這個 List 用於存放每個位置的右側比其小的數有幾個,List 的大小和原始陣列一樣

List<Integer> ans = new ArrayList<>(nums.length);

ans 在初始化的時候,均設定為 0 ,表示,所有位置都還沒計算過。

ans = [0,0,0,0,0,0]

接下來對原始陣列進行排序(注意:排序的時候,不能只使用值來排序,要帶上這個值所在的位置,這樣排序後才不會丟失該值在原始陣列中的位置資訊)

    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要記錄值,也要記錄位置,防止排序後找不到值對應的位置在哪裡
      arr[i] = new int[] {nums[i], i};
    }
    // 排序按值排序
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));

排序後,arr 按如下順序組織

{值:2,原始位置:4}
{值:4,原始位置:3}
{值:5,原始位置:0}
{值:7,原始位置:2}
{值:8,原始位置:1}
{值:9,原始位置:5}

接下來初始化開點線段樹,線段樹的size就是原始陣列的大小,且每個位置都是0,

按順序遍歷這個 arr 陣列,最小值 2 被取出,其原始位置是 4,且 4 號位置右側沒有比自己更小的數,接下來在開點線段樹中把把 4 號位置的值加1,表示 4 號位置被處理過了,線上段樹中查4號位置以後並沒有任何標記記錄,說明沒有比這個數更小的數了,直接設定4號位置的ans值為0

ans = [0,0,0,0,0,0]

線段樹中

seg = [0,0,0,0,1,0]

接下來是 3 號位置的4,線上段樹中查到,有一個比它小的,直接設定到 ans 中,然後線上段樹中把 3 號位置也標記為 1,說明處理過,

ans = [0,0,0,1,0,0]

線段樹中

seg = [0,0,0,1,1,0]

接下來是0號位置的5, 線上段樹中,查到右側有兩個標記過的,說明有兩個比它小的數,直接在 ans 中把 0 號位置設定為 2, 然後線上段樹中把 0 號位置標記為 1 ,說明處理過,此時

ans = [2,0,0,1,0,0]

線段樹中

seg = [1,0,0,1,1,0]

接下來是 2 號位置的 7, 線上段樹中,查到右側有兩個標記過的,說明有兩個比它小的數,直接在 ans 中把 2 號位置設定為 2, 然後線上段樹中把 2 號位置標記為 1 ,說明處理過,此時

ans = [2,0,2,1,0,0]

線段樹中

seg = [1,0,1,1,1,0]

接下來是 1 號位置的 8, 線上段樹中,查到右側有三個標記過的,說明有三個比它小的數,直接在 ans 中把 1 號位置設定為 3, 然後線上段樹中把 1 號位置標記為 1 ,說明處理過,此時

ans = [2,3,2,1,0,0]

線段樹中

seg = [1,1,1,1,1,0]

接下來是 5 號位置的 9, 線上段樹中,查到右側沒有標記過的,說明沒有比它小的數,直接在 ans 中把 5 號位置設定為 0, 然後線上段樹中把 5 號位置標記為 1 ,說明處理過,此時

ans = [2,3,2,1,0,0]

線段樹中

seg = [1,1,1,1,1,1]

以上就是整個流程。

核心程式碼如下

  public static List<Integer> countSmaller(int[] nums) {
    if (nums == null || nums.length == 0) {
      return new ArrayList<>();
    }
    int n = nums.length;
    List<Integer> ans = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
      ans.add(0);
    }
    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要記錄值,也要記錄位置,防止排序後找不到值對應的位置在哪裡
      arr[i] = new int[] {nums[i], i};
    }
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));
    DynamicSegmentTree dst = new DynamicSegmentTree(n);
    for (int[] num : arr) {
      ans.set(num[1], dst.query(num[1] + 1, n));
      dst.add(num[1] + 1, 1);
    }
    return ans;
  }

其中 DynamicSegmentTree 結構就是前面提到的動態開點線段樹的實現。

更多

演演算法和資料結構筆記