「學習筆記」記憶化搜尋

2023-06-13 18:00:45

由於我一直對搜尋情有獨鍾,因此,如果能寫記憶化搜尋的絕不會寫 for 迴圈 DP。
文章部分內容來自 \(\texttt{OI-Wiki}\)

引入

記憶化搜尋是一種通過記錄已經遍歷過的狀態的資訊,從而避免對同一狀態重複遍歷的搜尋實現方式。
因為記憶化搜尋確保了每個狀態只存取一次,它也是一種常見的動態規劃實現方式。

我們通過下面一道題來引入。

P1434 [SHOI2002] 滑雪
有一個 \(R \times C\) 的二維矩陣,可以從某個點到達上下左右相鄰四個點之一,當且僅當高度會減小,即這是一個下降序列。求這個下降序列長度的最大值。

我們的樸素 DFS 做法:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline ll read() {
	ll x = 0;
	int fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int fx[4] = {0, 1, 0, -1};
const int fy[4] = {1, 0, -1, 0};

int r, c, maxn;
int a[110][110], f[110][110];

int dfs(int x, int y) {
	int xx, yy, ma = 0;
	for (int i = 0; i <= 3; ++i) {
		xx = x + fx[i];
		yy = y + fy[i];
		if (xx > 0 && xx <= r && yy > 0 && yy <= c && a[xx][yy] < a[x][y]) {
			ma = max(dfs(xx, yy), ma);
		}
	}
	return ma + 1;
}

int main() {
	r = read(), c = read();
	for (int i = 1; i <= r; ++ i)
		for (int j = 1; j <= c; ++ j) {
			a[i][j] = read();
		}
	for (int i = 1; i <= r; ++ i)
		for (int j = 1; j <= c; ++ j) {
			maxn = max(maxn, dfs(i, j));
		}
	printf("%d", maxn);
	return 0;
}

交上去一看,T 了一個點。
為什麼 T 了呢?
我們假設 \((i, j)\) 這個點當前被搜到,繼續搜,得到最大值,返回了。
後來,又一次搜到了 \((i, j)\) 這個點,然後又重新搜了一遍;再後來,又搜到了這個點,又重新搜了一遍......
因此,導致我們的這份程式碼跑得慢的原因就是多次進行同一個操作,搜尋同一個變數。
為了提升速度,防止重複搜一種情況,我們設定記憶化陣列來儲存我們的值,同時阻止他繼續重複搜尋。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline ll read() {
	ll x = 0;
	int fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int fx[4] = {0, 1, 0, -1};
const int fy[4] = {1, 0, -1, 0};

int r, c, maxn;
int a[110][110], f[110][110];

int dfs(int x, int y) {
	if (f[x][y])	return f[x][y];
	f[x][y] = 1;
	int xx, yy, ma = 0;
	for (int i = 0; i <= 3; ++i) {
		xx = x + fx[i];
		yy = y + fy[i];
		if (xx > 0 && xx <= r && yy > 0 && yy <= c && a[xx][yy] < a[x][y]) {
			ma = max(dfs(xx, yy), ma);
		}
	}
	f[x][y] += ma;
	return f[x][y];
}

int main() {
	r = read(), c = read();
	for (int i = 1; i <= r; ++ i) {
		for (int j = 1; j <= c; ++ j) {
			a[i][j] = read();
		}
	}	
	for (int i = 1; i <= r; ++ i) {
		for (int j = 1; j <= c; ++ j) {
			maxn = max(maxn, dfs(i, j));
		}
	}
	printf("%d\n", maxn);
	return 0;
}

然後,你就可以愉快的 AC 了!
由此你也發現了,記憶化搜尋相較於一般搜尋速度快是因為避免了對同一狀態的重複遍歷。

寫記憶化搜尋方法

方法一

  1. 把這道題的 dp 狀態和方程寫出來
  2. 根據它們寫出 dfs 函數
  3. 新增記憶化陣列

方法二

  1. 寫出這道題的暴搜程式(最好是 dfs)
  2. 將這個 dfs 改成無需外部變數的 dfs
  3. 新增記憶化陣列

與遞推的區別

記憶化搜尋和遞推,都確保了同一狀態至多隻被求解一次。而它們實現這一點的方式則略有不同:遞推通過設定明確的存取順序來避免重複存取,記憶化搜尋雖然沒有明確規定存取順序,但通過給已經存取過的狀態打標記的方式,也達到了同樣的目的。
與遞推相比,記憶化搜尋因為不用明確規定存取順序,在實現難度上有時低於遞推,且能比較方便地處理邊界情況,這是記憶化搜尋的一大優勢。但與此同時,記憶化搜尋難以使用捲動陣列等優化,且由於存在遞迴,執行效率會低於遞推。因此應該視題目選擇更適合的實現方式。

題目

P1220 關路燈
在一條路線上安裝了 \(n\) 盞路燈,每盞燈的功率有大有小(即同一段時間內消耗的電量有多有少)。老張就住在這條路中間某一路燈旁,他有一項工作就是每天早上天亮時一盞一盞地關掉這些路燈。他每天都是在天亮時首先關掉自己所處位置的路燈,然後可以向左也可以向右去關燈。現在已知老張走的速度為 \(1m/s\),每個路燈的位置(是一個整數,即距路線起點的距離,單位:\(m\))、功率(\(W\)),老張關燈所用的時間很短而可以忽略不計。問:怎樣最省電?\(n \le 50\)

記憶化搜尋程式碼:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 60;

int n, c;
int s[N], w[N], sum[N];
int dp[N][N][2];

int dfs(int l, int r, int las) {
	if (~dp[l][r][las]) {
		return dp[l][r][las];
	}
	if (l == 1 && r == n) {
		return dp[l][r][las] = 0;
	}
	int minn = 1e9 + 5;
	if (las == 0) {
		if (l != 1 && r != n) {
			minn = min(dfs(l - 1, r, las) + (sum[n] - sum[r] + sum[l - 1])
				* (s[l] - s[l - 1]), dfs(l, r + 1, las ^ 1) + (sum[n] - 
					sum[r] + sum[l - 1]) * (s[r + 1] - s[l]));
		}
		else if (l != 1 && r == n) {
			minn = min(minn, dfs(l - 1, r, las) + (sum[n] - sum[r] + 
				sum[l - 1]) * (s[l] - s[l - 1]));
		}
		else if (l == 1 && r != n) {
			minn = min(minn, dfs(l, r + 1, las ^ 1) + (sum[n] - sum[r]
				+ sum[l - 1]) * (s[r + 1] - s[l]));
		}
	}
	else {
		if (l != 1 && r != n) {
			minn = min(dfs(l - 1, r, las ^ 1) + (sum[n] - sum[r] + sum[l - 1])
				* (s[r] - s[l - 1]), dfs(l, r + 1, las) + (sum[n] - 
					sum[r] + sum[l - 1]) * (s[r + 1] - s[r]));
		}
		else if (l != 1 && r == n) {
			minn = min(minn, dfs(l - 1, r, las ^ 1) + (sum[n] - sum[r] + 
				sum[l - 1]) * (s[r] - s[l - 1]));
		}
		else if (l == 1 && r != n) {
			minn = min(minn, dfs(l, r + 1, las) + (sum[n] - sum[r]
				+ sum[l - 1]) * (s[r + 1] - s[r]));
		}
	}
	return (dp[l][r][las] = minn);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	memset(dp, -1, sizeof dp);
	cin >> n >> c;
	for (int i = 1; i <= n; ++ i) {
		cin >> s[i] >> w[i];
		sum[i] = sum[i - 1] + w[i];
	}
	cout << min(dfs(c, c, 0), dfs(c, c, 1)) << '\n';
	return 0;
}

P3607 [USACO17JAN]Subsequence Reversal P
FJ 正在安排他的 \(N\) 頭奶牛排成一隊以拍照 \((1 \le n \le 50)\)。佇列中的第i頭奶牛的身高為 \(a_i\),並且 FJ 認為如果奶牛的身高序列中含有一個很長的不下降子序列的話,這會是一張很好的照片。
回憶一下,子序列是由牛序列中的一些元素 \(a_{i_1},a_{i_2},.....a_{i_k}\) 組成的子集。\((i_1<i_2< \cdots <i_k)\) 如果 \(a_{i_1} \le a_{i_2} \le a_{i_3} \le \cdots \le a_{i_k}\) 的話,我們就說這個序列是不下降的。
FJ 想要在他的奶牛序列中包括一個長期增長的子序列(也就是很長的不下降子序列)。為了確保這一點,他允許自己在一開始選擇任何子序列並逆轉其元素。
觀察這個子序列(上方英文)是如何反轉並佔據他們最初的相同的位置的,且其他元素保持不變。
在只能反轉一次任意子序列的情況下,請找到不下降子序列的最大可能長度。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 110;

int n;
int a[N], dp[60][60][60][60];
bool vis[60][60][60][60];

int dfs(int l, int r, int d, int u) {
	if (vis[l][r][d][u]) {
		return dp[l][r][d][u];
	}
	if (d > u)	return -1e8;
	if (l > r)	return 0;
	if (l == r) {
		if (d <= a[l] && a[r] <= u)	return dp[l][r][d][u] = 1;
		else	return 0;
	}
	dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l + 1, r - 1, d, u));
	if (a[r] >= d) {
		dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l + 1, r - 1, a[r], u) + 1);
	}
	if (a[l] <= u) {
		dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l + 1, r - 1, d, a[l]) + 1);
	}
	if (a[l] <= u && a[r] >= d) {
		dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l + 1, r - 1, a[r], a[l]) + 2);
	}
	dp[l][r][d][u] = max(dfs(l + 1, r, d, u), dp[l][r][d][u]);
	if (a[l] >= d) {
		dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l + 1, r, a[l], u) + 1);
	}
	dp[l][r][d][u] = max(dfs(l, r - 1, d, u), dp[l][r][d][u]);
	if (a[r] <= u) {
		dp[l][r][d][u] = max(dp[l][r][d][u], dfs(l, r - 1, d, a[r]) + 1);
	}
	vis[l][r][d][u] = 1;
	return dp[l][r][d][u];
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	cin >> n;
	for (int i = 1; i <= n; ++ i) {
		cin >> a[i];
	}
	cout << dfs(1, n, 0, 50) << '\n';
	return 0;
}