作者:Grey
原文地址: 根據先序遍歷和中序遍歷生成後序遍歷
假設有一棵二元樹
先序遍歷的結果是
中序遍歷的結果是
由於先序遍歷大的排程邏輯是,先頭,再左,再右
後序遍歷的排程邏輯是:先左,再右,再頭。
所以:後序遍歷的最後一個節點,一定是先序遍歷的頭節點。
定義遞迴函數
// 先序遍歷陣列pre的[l1...r1]區間
// 中序遍歷陣列in的[l2...r2]區間
// 生成後序遍歷陣列pos的[l3...r3]區間
void func(int[] pre, int l1, int r1, int[] in, int l2, int r2, int[] pos, int l3, r3)
依據以上推斷,可以得到如下結論
// 後序遍歷的最後一個節點,一定是先序遍歷的頭節點
pos[r3] = pre[l1];
然後,在中序陣列中,我們可以定位到這個頭節點的位置,即下圖中標黃的位置,假設這個位置是index
,
這個index
將中序陣列分成了左右兩個部分,由於中序遍歷的排程過程是:先左,再頭,再右,所以在中序遍歷中[l2......index]
區間內,是以index
位置為頭的左樹中序遍歷結果,[l2......index]
區間內元素個數假設為b
,那麼在先序遍歷中,從頭往後數b
個元素,即:[l1......l1+b]
構成了以index
位置為頭的左樹的先序遍歷結果。
public static void func(int[] pre, int l1, int r1, int[] in, int l2, int r2, int[] pos, int l3, int r3) {
if (l1 > r1) {
// 避免了無效情況
return;
}
if (l1 == r1) {
// 只有一個數的時候
pos[l3] = pre[l1];
} else {
// 不止一個數的時候
pos[r3] = pre[l1];
// index表示某個頭在中序陣列中的位置
int index;
for (index = l2; index <= r2; index++) {
if (in[index] == pre[l1]) {
break;
}
}
int b = index - l2;
// 構造左樹
func(pre, l1 + 1, l1 + b, in, l2, index - 1, pos, l3, l3 + b - 1);
// 構造右樹
func(pre, l1 + b + 1, r1, in, index + 1, r2, pos, l3 + b, r3 - 1);
}
}
在遞迴函數func
中,有一個遍歷的行為,
for (index = l2; index <= r2; index++) {
if (in[index] == pre[l1]) {
break;
}
}
如果每次遞迴都要遍歷一下,那麼效率會降低,所以可以在一開始就設定一個map
,存一下中序遍歷中每個值所在的位置資訊,這樣就不需要通過遍歷來找位置了,方法如下:
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < n; i++) {
inOrder[i] = in.nextInt();
map.put(inOrder[i], i);
}
這樣預處理以後,每次index
的位置不需要遍歷得到,只需要
int index = map.get(pre[l1]);
即可,完整程式碼見
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int n = in.nextInt();
int[] preOrder = new int[n];
int[] inOrder = new int[n];
for (int i = 0; i < n; i++) {
preOrder[i] = in.nextInt();
}
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < n; i++) {
inOrder[i] = in.nextInt();
map.put(inOrder[i], i);
}
int[] posOrder = new int[n];
func(preOrder, 0, n - 1, inOrder, 0, n - 1, posOrder, 0, n - 1, map);
for (int i = 0; i < n; i++) {
System.out.print(posOrder[i] + " ");
}
in.close();
}
public static void func(int[] pre, int l1, int r1, int[] in, int l2, int r2, int[] pos, int l3, int r3,
Map<Integer, Integer> map) {
if (l1 > r1) {
// 避免了無效情況
return;
}
if (l1 == r1) {
// 只有一個數的時候
pos[l3] = pre[l1];
} else {
// 不止一個數的時候
pos[r3] = pre[l1];
// index表示某個頭在中序陣列中的位置
int index = map.get(pre[l1]);
int b = index - l2;
func(pre, l1 + 1, l1 + b, in, l2, index - 1, pos, l3, l3 + b - 1, map);
func(pre, l1 + b + 1, r1, in, index + 1, r2, pos, l3 + b, r3 - 1, map);
}
}
}