このデータ構造の実装には、 yaketake08さんの解説 を大変参考にさせていただきました。
/*
<Segment Tree Beats>
様々なクエリに対応した万能遅延セグメント木(0-indexed)
tree.update_max(l,r,x):区間[l,r]内のiについてa[i]をmax(a[i],x)に更新
tree.update_min(l,r,x):区間[l,r]内のiについてa[i]をmin(a[i],x)に更新
tree.query_add(l,r,x):区間[l,r]内のiについてa[i]にxを加算
tree.query_val(l,r,x):区間[l,r]内のiについてa[i]をxに更新
tree.query_max(l,r):区間[l,r]に対する最大値クエリ
tree.query_min(l,r):区間[l,r]に対する最小値クエリ
tree.query_sum(l,r):区間[l,r]に対する区間和クエリ
*/
template <typename T>
struct Segment_Tree_Beats
{
private:
const T INF = numeric_limits<T>::max();
int n, n0;
vector<T> fmax; //最大値
vector<T> smax; //2番目に大きい値
vector<T> fmax_c; //最大値の個数
vector<T> fmin; //最小値
vector<T> smin; //2番目に小さい値
vector<T> fmin_c; //最小値の個数
vector<T> sum; //区間和
vector<T> len; //区間の長さ
vector<T> ladd; //遅延加算
vector<T> lval; //遅延更新
void update_node_max(int k, T x)
{
sum[k] += (x - fmax[k]) * fmax_c[k];
if (fmax[k] == fmin[k])
{
fmax[k] = x;
fmin[k] = x;
}
else
{
if (fmax[k] == smin[k])
{
fmax[k] = x;
smin[k] = x;
}
else
fmax[k] = x;
}
//遅延処理
if (lval[k] != INF && x < lval[k])
lval[k] = x;
}
void update_node_min(int k, T x)
{
sum[k] += (x - fmin[k]) * fmin_c[k];
if (fmin[k] == fmax[k])
{
fmin[k] = x;
fmax[k] = x;
}
else
{
if (fmin[k] == smax[k])
{
fmin[k] = x;
smax[k] = x;
}
else
fmin[k] = x;
}
//遅延処理
if (lval[k] != INF && lval[k] < x)
lval[k] = x;
}
//全体加算
void addall(int k, T x)
{
fmax[k] += x;
fmin[k] += x;
if (smax[k] != -INF)
smax[k] += x;
if (smin[k] != INF)
smin[k] += x;
sum[k] += len[k] * x;
if (lval[k] != INF)
lval[k] += x;
else
ladd[k] += x;
}
//全体更新
void updateall(int k, T x)
{
fmax[k] = x;
fmin[k] = x;
smax[k] = -INF;
smin[k] = INF;
fmax_c[k] = len[k];
fmin_c[k] = len[k];
sum[k] = x * len[k];
lval[k] = x;
ladd[k] = 0;
}
//親->子への伝播
void push(int k)
{
if ((k - n0 + 1) >= 0)
return;
//遅延伝播
if (lval[k] != INF)
{
updateall(2 * k + 1, lval[k]);
updateall(2 * k + 2, lval[k]);
lval[k] = INF;
return;
}
if (ladd[k] != 0)
{
addall(2 * k + 1, ladd[k]);
addall(2 * k + 2, ladd[k]);
ladd[k] = 0;
}
for (int i = 1; i <= 2; i++)
{
if (fmax[k] < fmax[2 * k + i])
update_node_max(2 * k + i, fmax[k]);
if (fmin[k] > fmin[2 * k + i])
update_node_min(2 * k + i, fmin[k]);
}
}
//子->親への伝播
void update(int k)
{
sum[k] = sum[2 * k + 1] + sum[2 * k + 2];
//最大値関連の更新(大きい方を優先)
if (fmax[2 * k + 1] < fmax[2 * k + 2])
{
fmax[k] = fmax[2 * k + 2];
fmax_c[k] = fmax_c[2 * k + 2];
smax[k] = max(fmax[2 * k + 1], smax[2 * k + 2]);
}
if (fmax[2 * k + 1] > fmax[2 * k + 2])
{
fmax[k] = fmax[2 * k + 1];
fmax_c[k] = fmax_c[2 * k + 1];
smax[k] = max(fmax[2 * k + 2], smax[2 * k + 1]);
}
if (fmax[2 * k + 1] == fmax[2 * k + 2])
{
fmax[k] = fmax[2 * k + 1];
fmax_c[k] = fmax_c[2 * k + 1] + fmax_c[2 * k + 2];
smax[k] = max(smax[2 * k + 1], smax[2 * k + 2]);
}
//最小値関連の更新(小さい方を優先)
if (fmin[2 * k + 1] < fmin[2 * k + 2])
{
fmin[k] = fmin[2 * k + 1];
fmin_c[k] = fmin_c[2 * k + 1];
smin[k] = min(fmin[2 * k + 2], smin[2 * k + 1]);
}
if (fmin[2 * k + 1] > fmin[2 * k + 2])
{
fmin[k] = fmin[2 * k + 2];
fmin_c[k] = fmin_c[2 * k + 2];
smin[k] = min(fmin[2 * k + 1], smin[2 * k + 2]);
}
if (fmin[2 * k + 1] == fmin[2 * k + 2])
{
fmin[k] = fmin[2 * k + 1];
fmin_c[k] = fmin_c[2 * k + 1] + fmin_c[2 * k + 2];
smin[k] = max(smin[2 * k + 1], smin[2 * k + 2]);
}
}
//区間[a,b)内のiについてa[i]をmax(a[i],x)に更新
void _update_max(T x, int a, int b, int k, int l, int r)
{
//区間をはみ出すorxが最小値以下
if (b <= l || r <= a || fmin[k] >= x)
return;
//(最小値)<x<(2番目に小さい値)のとき
if (l >= a && r <= b && smin[k] > x)
{
update_node_min(k, x);
return;
}
//親->子->親
push(k);
_update_max(x, a, b, 2 * k + 1, l, (l + r) / 2);
_update_max(x, a, b, 2 * k + 2, (l + r) / 2, r);
update(k);
}
//区間[a,b)内のiについてa[i]をmin(a[i],x)に更新
void _update_min(T x, int a, int b, int k, int l, int r)
{
//区間をはみ出すorxが最大値以上
if (b <= l || r <= a || fmax[k] <= x)
return;
//(2番目に大きい値)<x<(最大値)のとき
if (l >= a && r <= b && smax[k] < x)
{
update_node_max(k, x);
return;
}
//親->子->親
push(k);
_update_min(x, a, b, 2 * k + 1, l, (l + r) / 2);
_update_min(x, a, b, 2 * k + 2, (l + r) / 2, r);
update(k);
}
//区間[a,b)に対する最大値クエリ
T _query_max(int a, int b, int k, int l, int r)
{
if (b <= l || a >= r)
return -INF; //区間外
if (a <= l && r <= b)
return fmax[k];
push(k);
T L = _query_max(a, b, 2 * k + 1, l, (l + r) / 2);
T R = _query_max(a, b, 2 * k + 2, (l + r) / 2, r);
return max(L, R);
}
//区間[a,b)に対する最小値クエリ
T _query_min(int a, int b, int k, int l, int r)
{
if (b <= l || a >= r)
return INF; //区間外
if (a <= l && r <= b)
return fmin[k];
push(k);
T L = _query_min(a, b, 2 * k + 1, l, (l + r) / 2);
T R = _query_min(a, b, 2 * k + 2, (l + r) / 2, r);
return min(L, R);
}
//区間[a,b)に対する区間和クエリ
T _query_sum(int a, int b, int k, int l, int r)
{
if (b <= l || a >= r)
return 0; //区間外
if (a <= l && r <= b)
return sum[k];
push(k);
T L = _query_sum(a, b, 2 * k + 1, l, (l + r) / 2);
T R = _query_sum(a, b, 2 * k + 2, (l + r) / 2, r);
return (L + R);
}
//区間[a,b)に対する区間加算クエリ
void _query_add(T x, int a, int b, int k, int l, int r)
{
if (b <= l || a >= r)
return; //区間外
if (a <= l && r <= b)
{
addall(k, x);
return;
}
push(k);
_query_add(x, a, b, 2 * k + 1, l, (l + r) / 2);
_query_add(x, a, b, 2 * k + 2, (l + r) / 2, r);
update(k);
}
//区間[a,b)に対する区間更新クエリ
void _query_val(T x, int a, int b, int k, int l, int r)
{
if (b <= l || a >= r)
return; //区間外
if (a <= l && r <= b)
{
updateall(k, x);
return;
}
push(k);
_query_val(x, a, b, 2 * k + 1, l, (l + r) / 2);
_query_val(x, a, b, 2 * k + 2, (l + r) / 2, r);
update(k);
}
public:
Segment_Tree_Beats(int N) : n(N), fmax(4 * N), smax(4 * N), fmax_c(4 * N), fmin(4 * N), smin(4 * N),
fmin_c(4 * N), sum(4 * N), len(4 * N), ladd(4 * N), lval(4 * N)
{
n0 = 1;
while (n0 < n)
n0 *= 2;
len[0] = n0;
for (int i = 0; i < 2 * n0; i++)
{
lval[i] = INF;
}
for (int i = 0; i < n0 - 1; i++)
{
len[2 * i + 1] = len[i] / 2;
len[2 * i + 2] = len[i] / 2;
}
for (int i = 0; i < n; i++)
{
int nn = n0 - 1 + i;
smax[nn] = -INF;
smin[nn] = INF;
fmax_c[nn] = 1;
fmin_c[nn] = 1;
}
for (int i = n; i < n0; i++)
{
int nn = n0 - 1 + i;
fmax[nn] = -INF;
smax[nn] = -INF;
fmin[nn] = INF;
smin[nn] = INF;
}
for (int i = n0 - 2; i >= 0; i--)
{
update(i);
}
}
Segment_Tree_Beats(int N, vector<T> &A) : n(N), fmax(4 * N), smax(4 * N), fmax_c(4 * N), fmin(4 * N), smin(4 * N),
fmin_c(4 * N), sum(4 * N), len(4 * N), ladd(4 * N), lval(4 * N)
{
n0 = 1;
while (n0 < n)
n0 *= 2;
len[0] = n0;
for (int i = 0; i < 2 * n0; i++)
{
lval[i] = INF;
}
for (int i = 0; i < n0 - 1; i++)
{
len[2 * i + 1] = len[i] / 2;
len[2 * i + 2] = len[i] / 2;
}
for (int i = 0; i < n; i++)
{
int nn = n0 - 1 + i;
fmax[nn] = A[i];
fmin[nn] = A[i];
sum[nn] = A[i];
smax[nn] = -INF;
smin[nn] = INF;
fmax_c[nn] = 1;
fmin_c[nn] = 1;
}
for (int i = n; i < n0; i++)
{
int nn = n0 - 1 + i;
fmax[nn] = -INF;
smax[nn] = -INF;
fmin[nn] = INF;
smin[nn] = INF;
}
for (int i = n0 - 2; i >= 0; i--)
{
update(i);
}
}
//区間[l,r]内のiについてa[i]をmax(a[i],x)に更新
void update_max(int l, int r, T x)
{
_update_max(x, l, r + 1, 0, 0, n0);
}
//区間[l,r]内のiについてa[i]をmin(a[i],x)に更新
void update_min(int l, int r, T x)
{
_update_min(x, l, r + 1, 0, 0, n0);
}
//区間[l,r]内のiについてa[i]にxを加算
void query_add(int l, int r, T x)
{
_query_add(x, l, r + 1, 0, 0, n0);
}
//区間[l,r]内のiについてa[i]をxに更新
void query_val(int l, int r, T x)
{
_query_val(x, l, r + 1, 0, 0, n0);
}
//区間[l,r]に対する最大値クエリ
T query_max(int l, int r)
{
return _query_max(l, r + 1, 0, 0, n0);
}
//区間[l,r]に対する最小値クエリ
T query_min(int l, int r)
{
return _query_min(l, r + 1, 0, 0, n0);
}
//区間[l,r]に対する区間和クエリ
T query_sum(int l, int r)
{
return _query_sum(l, r + 1, 0, 0, n0);
}
};