【資料結構】吉司機線段樹

2023-05-31 21:00:38

【資料結構】吉司機線段樹(Segment Tree Beats)

吉司機線段樹,是由杭州學軍中學的吉如一在2016年國集論文當中提出的,解決了區間最值操作和區間歷史最值問題。

題目描述

給出一個長度為 \(n\) 的數列 \(A\),同時定義一個輔助陣列 \(B\)\(B\) 開始與 \(A\) 完全相同。接下來進行了 \(m\) 次操作,操作有五種型別,按以下格式給出:

  • 1 l r k:對於所有的 \(i\in[l,r]\),將 \(A_i\) 加上 \(k\)\(k\) 可以為負數)。
  • 2 l r v:對於所有的 \(i\in[l,r]\),將 \(A_i\) 變成 \(\min(A_i,v)\)
  • 3 l r:求 \(\sum_{i=l}^{r}A_i\)
  • 4 l r:對於所有的 \(i\in[l,r]\),求 \(A_i\) 的最大值。
  • 5 l r:對於所有的 \(i\in[l,r]\),求 \(B_i\) 的最大值。

在每一次操作後,我們都進行一次更新,讓 \(B_i\gets\max(B_i,A_i)\)

演演算法描述

對於1、3、4操作,是最基本的線段樹區間加,區間求和、求max的操作,詳見P3372 【模板】線段樹 1 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn),我們先來解決操作2,我們發現想要動態探測哪些值大於\(v\),並進行min操作十分艱難,因為每個大於\(v\)的數不一樣,所以每個大於\(v\)的數要減去的數就不一樣,很難統計。

那就讓區間內大於\(v\)的數只有一個好了。

對於一個區間,我們記錄一個最大值\(maxn\),和一個嚴格次大值\(sec\)(只有一個值時為\(-inf\)),對於區間取\(min\)操作,我們分為以下三種情況討論:

1.\(k \geq maxn\)\(k\)在這個區間之內不能更新任何一個值,直接\(return\)

2.\(maxn > k > sec\)\(k\)只能更新\(maxn\),讓\(sum -= maxn * cnt\)(區間最大值個數)\(,maxn = k\),因為要下傳操作,所以讓\(tag2\)(更新\(maxn\)\(tag\))減去(\(maxn - k\))

3.\(k \leq sec\),分別向\(lc\)\(rc\)兩邊遞迴更新答案。

這樣我們就完成了對最小值的更新,同時向兩邊遞迴,會不會影響複雜度?事實上這個操作的複雜度仍然是\(O(logn)\)的,證明如下:
(選自吉如一2016國家集訓隊論文)

程式碼:

inline void modify_min(int l,int r,int L,int R,int k,int pos)
{
	if(k >= t[pos].maxn) return;
	if(L <= l && r <= R && k > t[pos].sec)
	{
		t[pos].sum -= 1ll * t[pos].cnt * (t[pos].maxn - k);
		t[pos].tag2 -= t[pos].maxn - k;
		t[pos].maxn = k;
		return;
	}
	pushdown(pos,l,r);
	int mid = (l + r) >> 1;
	if(L <= mid) modify_min(l,mid,L,R,k,pos << 1);
	if(R > mid) modify_min(mid + 1,r,L,R,k,pos << 1 | 1);
	pushup(pos);
}

現在來看操作5,我們發現,因為修改只有加,一個位置的歷史最大值,其實是這個位置原來的值(不一定是原始值,可以看作是pushdown之前的值)加上pushdown之前出現過最大的tag,我們發現在上面的\(modify\_min\)操作中最大值會單獨改變,所以區間中的最大值與其他數改變的量(也就是tag)是不一樣的,所以我們記錄\(tag1\),\(tag2\)代表其他數,最大值的改變數,用\(tag3\),\(tag4\)表示\(tag1\),\(tag2\)在pushdown之前的最大值,這樣我們在pushdown的時候就有:

\[history\_max_{pos} = max\{history\_max_{pos},max_{pos} + tag4\} \]

每次更新\(k1\)\(k3\)的時候,都相應的更新\(k2\)\(k4\)

其實最難的部分在更新標記上,這裡我們先講上傳:

和、最大值、歷史最大值都可以左右區間直接合並更新,但是對於次大值和最大值個數,我們分類討論:

1.\(maxn_{lc} == maxn_{rc}\) : 這個時候次大值應當等於兩兒子的次大值取\(max\),而最大值計數等於兩邊相加。

2.\(maxn_{lc} > maxn_{rc}\) : 次大值等於左邊的次大值和右邊的最大值取\(max\),最大值計數等於左邊的\(cnt\)

3.\(maxn_{lc} < maxn_{rc}\) : 次大值等於左邊的最大值和右邊的次大值取\(max\),最大值計數等於右邊的\(cnt\)

這樣就完成了上傳

inline void pushup(int pos)
{
	t[pos].sum = t[pos << 1].sum + t[pos << 1 | 1].sum;
	t[pos].maxn = max(t[pos << 1].maxn,t[pos << 1 | 1].maxn);
	t[pos].hismax = max(t[pos << 1].hismax,t[pos << 1 | 1].hismax);
	if(t[pos << 1].maxn == t[pos << 1 | 1].maxn) 
		t[pos].sec = max(t[pos << 1].sec,t[pos << 1 | 1].sec),t[pos].cnt = t[pos << 1].cnt + t[pos << 1 | 1].cnt;
	else if(t[pos << 1].maxn > t[pos << 1 | 1].maxn)
		t[pos].sec = max(t[pos << 1].sec,t[pos << 1 | 1].maxn),t[pos].cnt = t[pos << 1].cnt;
	else 
		t[pos].sec = max(t[pos << 1].maxn,t[pos << 1 | 1].sec),t[pos].cnt = t[pos << 1 | 1].cnt;
}

對於下傳,我們也需要分類討論:
如果全域性最大值在左邊,那麼左邊的最大值要按照最大值的方式來更新(即將\(tag2\)\(tag4\)傳下去),否則就將左邊不管是不是左邊的最大值,都用\(tag1\)\(tag3\)來更新。

如果全域性最大值在右邊,同理。

注意這兩個條件可以同時成立,不要寫else

inline void pushdown(int pos,int l,int r)
{
	int mid = (l + r) >> 1;
	int mx = max(t[pos << 1].maxn,t[pos << 1 | 1].maxn);
	if(mx == t[pos << 1].maxn) change(pos << 1,l,mid,t[pos].tag1,t[pos].tag2,t[pos].tag3,t[pos].tag4);
	else change(pos << 1,l,mid,t[pos].tag1,t[pos].tag1,t[pos].tag3,t[pos].tag3);
	if(mx == t[pos << 1 | 1].maxn) change(pos << 1 | 1,mid + 1,r,t[pos].tag1,t[pos].tag2,t[pos].tag3,t[pos].tag4);
	else change(pos << 1 | 1,mid + 1,r,t[pos].tag1,t[pos].tag1,t[pos].tag3,t[pos].tag3);
	t[pos].tag1 = 0;
	t[pos].tag2 = 0;
	t[pos].tag3 = 0;
	t[pos].tag4 = 0;
}
inline void change(int pos,int l,int r,int k1,int k2,int k3,int k4)
{
	t[pos].sum += 1ll * (r - l + 1 - t[pos].cnt) * k1 + 1ll * t[pos].cnt * k2;
	t[pos].hismax = max(t[pos].hismax,t[pos].maxn + k4);
	t[pos].maxn += k2;
	if(t[pos].sec != -inf) t[pos].sec += k1;
	t[pos].tag4 = max(t[pos].tag4,t[pos].tag2 + k4);
	t[pos].tag3 = max(t[pos].tag3,t[pos].tag1 + k3);
	t[pos].tag1 += k1;
	t[pos].tag2 += k2;
}

注意\(sec\)一行,如果這個節點沒有次大值(就是整個區間只有一個值),那麼就不能更新值為\(-inf\)的次大值。

區間加時注意要更新全部變數:

inline void modify_add(int l,int r,int L,int R,int k,int pos)
{ 
	if(L <= l && r <= R)
	{
		t[pos].sum += 1ll * k * (r - l + 1 - t[pos].cnt) + 1ll * k * t[pos].cnt;
		t[pos].maxn += k;
		t[pos].hismax = max(t[pos].hismax,t[pos].maxn);
		if(t[pos].sec != -inf) t[pos].sec += k;
		t[pos].tag1 += k;t[pos].tag2 += k;
		t[pos].tag3 = max(t[pos].tag3,t[pos].tag1);
		t[pos].tag4 = max(t[pos].tag4,t[pos].tag2); 
		return;
	}
	pushdown(pos,l,r);
	int mid = (l + r) >> 1;
	if(L <= mid) modify_add(l,mid,L,R,k,pos << 1);
	if(R > mid) modify_add(mid + 1,r,L,R,k,pos << 1 | 1);
	pushup(pos);
}

對最大值,和,歷史最大值的查詢正常查詢就好,注意在建樹的時候給次大值附上\(-inf\)的初值

Code

#include<bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5,inf = 2e9;
struct Node{
	long long sum;
	int maxn,sec,cnt,hismax,tag1,tag2,tag3,tag4;
}t[N * 4];
int n,m;
inline void pushup(int pos)
{
	t[pos].sum = t[pos << 1].sum + t[pos << 1 | 1].sum;
	t[pos].maxn = max(t[pos << 1].maxn,t[pos << 1 | 1].maxn);
	t[pos].hismax = max(t[pos << 1].hismax,t[pos << 1 | 1].hismax);
	if(t[pos << 1].maxn == t[pos << 1 | 1].maxn) 
		t[pos].sec = max(t[pos << 1].sec,t[pos << 1 | 1].sec),t[pos].cnt = t[pos << 1].cnt + t[pos << 1 | 1].cnt;
	else if(t[pos << 1].maxn > t[pos << 1 | 1].maxn)
		t[pos].sec = max(t[pos << 1].sec,t[pos << 1 | 1].maxn),t[pos].cnt = t[pos << 1].cnt;
	else 
		t[pos].sec = max(t[pos << 1].maxn,t[pos << 1 | 1].sec),t[pos].cnt = t[pos << 1 | 1].cnt;
}
inline void change(int pos,int l,int r,int k1,int k2,int k3,int k4)
{
	t[pos].sum += 1ll * (r - l + 1 - t[pos].cnt) * k1 + 1ll * t[pos].cnt * k2;
	t[pos].hismax = max(t[pos].hismax,t[pos].maxn + k4);
	t[pos].maxn += k2;
	if(t[pos].sec != -inf) t[pos].sec += k1;
	t[pos].tag4 = max(t[pos].tag4,t[pos].tag2 + k4);
	t[pos].tag3 = max(t[pos].tag3,t[pos].tag1 + k3);
	t[pos].tag1 += k1;
	t[pos].tag2 += k2;
}
inline void pushdown(int pos,int l,int r)
{
	int mid = (l + r) >> 1;
	int mx = max(t[pos << 1].maxn,t[pos << 1 | 1].maxn);
	if(mx == t[pos << 1].maxn) change(pos << 1,l,mid,t[pos].tag1,t[pos].tag2,t[pos].tag3,t[pos].tag4);
	else change(pos << 1,l,mid,t[pos].tag1,t[pos].tag1,t[pos].tag3,t[pos].tag3);
	if(mx == t[pos << 1 | 1].maxn) change(pos << 1 | 1,mid + 1,r,t[pos].tag1,t[pos].tag2,t[pos].tag3,t[pos].tag4);
	else change(pos << 1 | 1,mid + 1,r,t[pos].tag1,t[pos].tag1,t[pos].tag3,t[pos].tag3);
	t[pos].tag1 = 0;
	t[pos].tag2 = 0;
	t[pos].tag3 = 0;
	t[pos].tag4 = 0;
}
inline void build(int l,int r,int pos)
{
	if(l == r)
	{
		cin>>t[pos].sum;
		t[pos].maxn = t[pos].sum;
		t[pos].hismax = t[pos].maxn;
		t[pos].sec = -inf;
		t[pos].tag1 = t[pos].tag2 = t[pos].tag3 = t[pos].tag4 = 0;
		t[pos].cnt = 1;
		return;
	}
	int mid = (l + r) >> 1;
	build(l,mid,pos << 1);
	build(mid + 1,r,pos << 1 | 1);
	pushup(pos);
}
inline void modify_add(int l,int r,int L,int R,int k,int pos)
{ 
	if(L <= l && r <= R)
	{
		t[pos].sum += 1ll * k * (r - l + 1 - t[pos].cnt) + 1ll * k * t[pos].cnt;
		t[pos].maxn += k;
		t[pos].hismax = max(t[pos].hismax,t[pos].maxn);
		if(t[pos].sec != -inf) t[pos].sec += k;
		t[pos].tag1 += k;t[pos].tag2 += k;
		t[pos].tag3 = max(t[pos].tag3,t[pos].tag1);
		t[pos].tag4 = max(t[pos].tag4,t[pos].tag2); 
		return;
	}
	pushdown(pos,l,r);
	int mid = (l + r) >> 1;
	if(L <= mid) modify_add(l,mid,L,R,k,pos << 1);
	if(R > mid) modify_add(mid + 1,r,L,R,k,pos << 1 | 1);
	pushup(pos);
}
inline void modify_min(int l,int r,int L,int R,int k,int pos)
{
	if(k >= t[pos].maxn) return;
	if(L <= l && r <= R && k > t[pos].sec)
	{
		t[pos].sum -= 1ll * t[pos].cnt * (t[pos].maxn - k);
		t[pos].tag2 -= t[pos].maxn - k;
		t[pos].maxn = k;
		return;
	}
	pushdown(pos,l,r);
	int mid = (l + r) >> 1;
	if(L <= mid) modify_min(l,mid,L,R,k,pos << 1);
	if(R > mid) modify_min(mid + 1,r,L,R,k,pos << 1 | 1);
	pushup(pos);
}
inline long long query_sum(int l,int r,int L,int R,int pos)
{
	if(L <= l && r <= R) return t[pos].sum;
	pushdown(pos,l,r);
	int mid = (l + r) >> 1;
	long long ret = 0;
	if(L <= mid) ret += query_sum(l,mid,L,R,pos << 1);
	if(R > mid) ret += query_sum(mid + 1,r,L,R,pos << 1 | 1);
	pushup(pos);
	return ret;
}
inline int query_max(int l,int r,int L,int R,int pos)
{
	if(L <= l && r <= R) return t[pos].maxn;
	pushdown(pos,l,r);
	int mid = (l + r) >> 1,ret = -inf;
	if(L <= mid) ret = max(ret,query_max(l,mid,L,R,pos << 1));
	if(R > mid) ret = max(ret,query_max(mid + 1,r,L,R,pos << 1 | 1));
	pushup(pos);
	return ret;
}
inline int query_hismax(int l,int r,int L,int R,int pos)
{
	if(L <= l && r <= R) return t[pos].hismax;
	pushdown(pos,l,r);
	int mid = (l + r) >> 1,ret = -inf;
	if(L <= mid) ret = max(ret,query_hismax(l,mid,L,R,pos << 1));
	if(R > mid) ret = max(ret,query_hismax(mid + 1,r,L,R,pos << 1 | 1));
	pushup(pos);
	return ret;
}
int main()
{
	ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	cin>>n>>m;
	build(1,n,1);
	int op,l,r,k;
	for(int i = 1;i <= m;i++)
	{
		cin>>op>>l>>r;
		switch(op)
		{
			case 1:
				cin>>k;
				modify_add(1,n,l,r,k,1);
				break;
			case 2:
				cin>>k;
				modify_min(1,n,l,r,k,1);
				break;
			case 3:
				cout<<query_sum(1,n,l,r,1)<<endl;
				break;
			case 4:
				cout<<query_max(1,n,l,r,1)<<endl;
				break;
			case 5:
				cout<<query_hismax(1,n,l,r,1)<<endl;
				break;
		}
	}
	return 0;
}

"所有的真理,都是符合客觀事實的,更是順著邏輯的。"