淺談斜率優化DP

2023-11-13 18:01:31

前言

考試 T2 出題人放了個樹上斜率優化 DP,直接被同校 OIER 吊起來錘。

離 NOIP 還有不到一週,趕緊學一點。

引入

斜率

斜率,數學、幾何學名詞,是表示一條直線(或曲線的切線)關於(橫)座標軸傾斜程度的量。它通常用直線(或曲線的切線)與(橫)座標軸夾角的正切,或兩點的縱座標之差與橫座標之差的比來表示。

斜率可以用來描述一個坡的傾斜程度,公式 \(k = \frac{\Delta y}{\Delta x}\)

初中學過一元一次函數 \(y = kx + b\),這裡的 \(k\) 就是這個函數表示的直線的斜率。

解決什麼

一般對於形如 \(f[i] = \min(a[i] \times b[j] + c[i] + d[j])\) 這種型別的 DP 轉移式子都可以用上斜率優化。

其中 \(b\) 要滿足單調遞增。

看到中間有一部分與 \(i,j\) 都有關,所以這個時候要用到斜率優化。

理解

下面來以一道題目為例進行講解。

P3195 [HNOI2008] 玩具裝箱

看完題目應該都可以想出來一個 \(O(n^2)\) 的 DP,那就是:

\(f[i]\) 表示考慮到第 \(i\) 個玩具所用的最小花費,\(sum[i]\) 為從 \(1\sim i\) 的玩具長度總和。

\[f[i] = \min\{f[j] + (sum[i] - sum[j] + i - j - L - 1)^2\} \]

我們嘗試把這一堆東西分分類,把只有 \(i\) 的挪到一起,只有 \(j\) 的挪到一起,剩下的挪到中間。

得到:

\[f[i] = \min\{f[j] + (sum[i] + i - sum[j] - j - L - 1)^2\} \]

\(A=sum[i] + i, B = sum[j] - j - L - 1\)

那麼就是 :

\[f[i] = f[j] + A^2 -2AB + B^2 \]

顯然的,\(A^2\) 我們可以預處理,是已知的,由於字首和,而且玩具長度至少為 \(1\),所以 \(2A\) 是嚴格單調遞增的,\(B\) 陣列我們也可以直接預處理。

\[f[j] + B^2 = 2AB + f[j] + A^2 \]

這個式子是把只與 \(j\) 有關的移到左邊了,可以發現形式上是和 \(y = kx + b\) 一樣的。

那麼我們就可以把一個之前轉移完成的狀態看成是一個 \((B, f[j] + B^2)\) 的點,而 \(2A\) 就是經過他們的直線的斜率。

那麼我們要求 \(f[i]\) 的話,就是求這個點和這個斜率為 \(2A\) 的直線的最大可能截距是多少。

於影象中

假設下面的三個點是我們待選的狀態:

假設我們當前要求的斜率畫出來是下面這樣:

我們就從下往上,一點一點向上挪,直到碰到的第一個點,此時的截距一定最大。我們也能看出的確 \(C\) 點最優。

那麼此時的 \(A\) 點好像沒有什麼用了,可以扔掉嗎?

答案是可以,因為斜率是單調遞增的,既然這次第一個碰不到 \(A\),那麼後面肯定也不是第一個碰到。

但是我們如何做到最快找出呢?

佇列維護

觀察這張圖片,假設裡面的點都是之前轉移完的狀態。

比較 \(AE,AB\) 的斜率。

不難發現 \(AB\) 的斜率比 \(AE\) 小,想一下之前說的,如果拿一條直線去碰這個圖形,從各個角度去碰,最外層的點會形成一個凸包,而這個凸包內的點,是無論如何都碰不到的。

這個我們可以用一個佇列來維護一個下凸殼,也就是凸包的一部分。

然後根據上面說的,要是佇列頭的兩個元素形成的直線斜率比當前的小,也可以直接彈出。

這樣佇列的隊頭元素就是我們要轉移的值了。

code:


/*
 * @Author: Aisaka_Taiga
 * @Date: 2023-11-13 14:11:27
 * @LastEditTime: 2023-11-13 15:09:40
 * @LastEditors: Aisaka_Taiga
 * @FilePath: \Desktop\P3195.cpp
 * The heart is higher than the sky, and life is thinner than paper.
 */
#include <bits/stdc++.h>

#define pf(x) ((x) * (x))
#define int long long
#define DB double
#define N 1000100

using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
    while(c <= '9' && c >= '0') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return x * f;
}

int n, L, q[N], c[N], f[N], sum[N], A[N], B[N];

inline int X(int x){return B[x];}

inline int Y(int x){return f[x] + pf(B[x]);}

inline DB xl(int i, int j){return (Y(i) - Y(j)) * 1.0 / (X(i) - X(j));}

signed main()
{
    n = read(), L = read();
    for(int i = 1; i <= n; i ++) c[i] = read();
    for(int i = 1; i <= n; i ++)
    {
        sum[i] = sum[i - 1] + c[i];
        B[i] = sum[i] + i + L + 1;
        A[i] = sum[i] + i;
    }
    B[0] = L + 1;
    int h = 1, t = 1;
    for(int i = 1; i <= n; i ++)
    {
        while(h < t && xl(q[h], q[h + 1]) < 2 * A[i]) h ++;
        int j = q[h];
        f[i] = f[j] + pf(A[i] - B[j]);
        while(h < t && xl(q[t - 1], i) < xl(q[t - 1], q[t])) t --;
        q[++ t] = i;
    }
    cout << f[n] << endl;
    return 0;
}

參考:https://www.cnblogs.com/terribleterrible/p/9669614.html