「學習筆記」Garsia-Wachs 演演算法

2023-06-15 06:01:47

前言
本文的資料和圖片均來自 \(\texttt{OI-Wiki}\)

引入

題目描述
在一個操場上擺放著一排 \(N\) 堆石子。現要將石子有次序地合併成一堆。規定每次只能選相鄰的 \(2\) 堆石子合併成新的一堆,並將新的一堆石子數記為該次合併的得分。
試設計一個演演算法,計算出將 \(N\) 堆石子合併成一堆的最小得分。
\((N \leq 40000)\)

過程

我們看到這個題,自然而然會想到區間 DP,即樸素的做法。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

ll r[600], g[600];
ll dp[600][600];

int main() {
	int n;
	scanf("%d", &n);
	for (int i = 1; i <= n; ++ i) {
		scanf("%lld", &r[i]);
		r[i + n] = r[i];
		g[i] = g[i - 1] + r[i];
		dp[i][i] = 0;
	}
	for (int i = n + 1; i <= 2 * n; ++ i) {
		dp[i][i] = 0;
		g[i] = g[i - 1] + r[i];
	}
	for (int l = 1; l < n; ++ l) {
		for (int i = 1, j = i + l; i < n * 2 && j <= n * 2; ++ i, j = i + l) {
			dp[i][j] = 100000000;
			for (int k = i; k < j; ++ k) {
				dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j] + g[j] - g[i - 1]);
			}
		}
	}
	ll minn = 0x3f3f3f3f;
	for (int i = 1; i <= n; ++ i) {
		minn = min(minn, dp[i][i + n - 1]);
	}
	printf("%lld", minn);
	return 0;
}

交上去後,你會發現,RE 了 \(7\) 個。
為什麼?
因為 \(n\) 太大了,二維陣列開不下,其次就算是用了什麼不為人知的手段開下了這麼大的陣列,\(n^2\) 的複雜度也鐵定超時。
這可怎麼辦呢?
下面介紹一種專門處理石子合併這類問題的演演算法——Garsia-Wachs 演演算法

Garsia-Wachs 演演算法

Garsia-Wachs 的步驟如下:
在序列的兩端設定極大值。
在序列中找到前三個連續的權重值 \(x, y, z\) 使得 \(x \leq z\)。因為序列結尾的最大值大於之前的任意兩個有限值,所以總是存在這樣的三元組。
從序列中移除 \(x\)\(y\),並在原來 \(x\) 的位置以前大於或等於 \(x+y\) 且距 \(x\) 最近的值的右邊重新插入元素,元素值為 \(x+y\)。因為左端最大值的存在,所以總是存在這樣的位置。
為了有效地實現這一階段,該演演算法可以在任何平衡二叉查詢樹結構中維護當前值序列。這樣的結構允許我們在對數時間內移除 \(x\)\(y\),並重新插入新節點 \(x + y\)
在每一步中,陣列中位於偶數索引上直到 \(y\) 值的權重形成了一個遞減序列,位於奇數索引位的權重形成另一個遞減序列。因此,重新插入 \(x+y\) 的位置可以通過在對數時間內對這兩個遞減序列使用平衡樹執行兩次二分查詢找到。通過從前一個三元組 \(z\) 值開始的線性順序搜尋,我們可以在匯流排性時間複雜度內執行對滿足 \(x \leq z\) 的第一個位置的搜尋。
如果實在不會平衡樹,vectorinserterase 操作也是個不錯的選擇呢!
Garsia-Wachs 演演算法的總時間複雜度為 \(O(n\log n)\),時間複雜度證明?我只能說,學 OI 記住結論就好了,證明,那是數學要考慮的事,不是 OI 要考慮的事 考試又不會讓你證明時間複雜度
至於正確性的證明我也不會= =,這個演演算法應用範圍十分有限,因此學的價值不是很高,「會用」 + 「知道有這個東西」 就行了
關於上面那道引入題的程式碼:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline ll read() {
	ll x = 0;
	int fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 4e4 + 5;

int n, ans;
vector<int> g;

int merge() {
	int k = g.size() - 2;
	for (int i = 0; i <= k; ++ i) {
		if (g[i] <= g[i + 2]) {
			k = i;
			break;
		}
	}
	int tmp = g[k] + g[k + 1];
	g.erase(g.begin() + k);
	g.erase(g.begin() + k);
	int t = -1;
	for (int i = k - 1; i >= 0; -- i) {
		if (g[i] >= tmp) {
			t = i;
			break;
		}
	}
	g.insert(g.begin() + t + 1, tmp);
	return tmp;
}

int main() {
	n = read();
	for (int i = 1; i <= n; ++ i) {
		g.emplace_back(read());
	}
	for (int i = 1; i < n; ++ i) {
		ans += merge();
	}
	printf("%d\n", ans);
	return 0;
}