思路

首先,看到区间修改区间查询应该想到线段树。

然后显然,我们可以用线段树快速求出 i=lraiix\sum_{i = l}^r a_i |i - x| 的值。

每个节点可以维护 33 个信息:

lsumlsum 的值为 i=lr(ri+1)×ai\sum_{i = l}^r (r - i + 1) \times a_i

rsumrsum 的值为 i=lr(il+1)×ai\sum_{i = l}^r (i - l + 1) \times a_i

sumsum 的值为 i=lrai\sum_{i = l}^r a_i

区间加时:

lsum=lsum+siz(siz+1)2xlsum = lsum + \dfrac{siz(siz + 1)}{2}x

rsum=rsum+siz(siz+1)2xrsum = rsum + \dfrac{siz(siz + 1)}{2}x

sum=sum+siz×xsum = sum + siz \times x

合并时:

lsum=lsonlsum+lsonsum×rsonsiz+rsonlsumlsum = lson_{lsum} + lson_{sum} \times rson_{siz} + rson_{lsum}

rsum=lsonrsum+rsonsum×lsonsiz+rsonrsumrsum = lson_{rsum} + rson_{sum} \times lson_{siz} + rson_{rsum}

sum=lsonsum+rsonsumsum = lson_{sum} + rson_{sum}

为什么这样合并?以 lsumlsum 举例:

lsonlsum+lsonsum×rsonsiz+rsonlsumlson_{lsum} + lson_{sum} \times rson_{siz} + rson_{lsum}

=i=lsonllsonr(lsonri+1)×ai+i=lsonllsonrrsonsiz×ai+i=rsonlrsonr(rsonri+1)×ai= \sum_{i = lson_l}^{lson_r} (lson_r - i + 1) \times a_i + \sum_{i = lson_l}^{lson_r} rson_{siz} \times a_i + \sum_{i = rson_l}^{rson_{r}} (rson_{r} - i + 1) \times a_i

根据乘法分配律,得:

i=lsonllsonr(lsonri+1+rsonsiz)×ai+i=rsonlrsonr(rsonri+1)×ai\sum_{i = lson_l}^{lson_r} (lson_r - i + 1 + rson_{siz}) \times a_i + \sum_{i = rson_l}^{rson_r} (rson_r - i + 1) \times a_i

因为线段树上两个子节点存储的区间是相邻的,即 rsonl=lsonr+1rson_l = lson_r + 1

显然,l=lsonll = lson_lr=rsonrr = rson_r。得:

i=llsonr(lsonri+1+(r(lsonr+1)+1))×ai+i=lsonr+1r(ri+1)×ai\sum_{i = l}^{lson_r} (lson_r - i + 1 + (r - (lson_r + 1) + 1)) \times a_i + \sum_{i = lson_r + 1}^{r} (r - i + 1) \times a_i

i=llsonr(lsonri+1+(rlsonr1+1))×ai+i=lsonr+1r(ri+1)×ai\sum_{i = l}^{lson_r} (lson_r - i + 1 + (r - lson_r - 1 + 1)) \times a_i + \sum_{i = lson_r + 1}^{r} (r - i + 1) \times a_i

i=llsonr(lsonri+1+rlsonr)×ai+i=lsonr+1r(ri+1)×ai\sum_{i = l}^{lson_r} (lson_r - i + 1 + r - lson_r) \times a_i + \sum_{i = lson_r + 1}^{r} (r - i + 1) \times a_i

i=llsonr(lsonrlsonri+1+r)×ai+i=lsonr+1r(ri+1)×ai\sum_{i = l}^{lson_r} (lson_r - lson_r - i + 1 + r) \times a_i + \sum_{i = lson_r + 1}^{r} (r - i + 1) \times a_i

i=llsonr(ri+1)×ai+i=lsonr+1r(ri+1)×ai\sum_{i = l}^{lson_r} (r - i + 1) \times a_i + \sum_{i = lson_r + 1}^{r} (r - i + 1) \times a_i

这两个 \sum 可以合并在一起,变成:

i=lr(ri+1)×ai\sum_{i = l}^r (r - i + 1) \times a_i

rsumrsum 也是类似的。

这样,我们就可以成功地查询任意一个区间的 lsumlsumrsumrsum 了!

查询 i=lraiix\sum_{i = l}^r a_i |i - x|,就相当于查询 [l,x)[l, x)lsum+(x,r]lsum + (x, r]rsumrsum

然后,我们就可以在 [l,r][l, r] 里暴力枚举 xx,时间复杂度 O(nmlogn)O(nm \log n)


我们考虑继续优化刚才的算法。

首先,对于 lx<rl \le x < r,将 xx 移到 x+1x + 1i=lraiix\sum_{i = l}^r a_i |i - x| 会增加 i=lxaii=x+1rai\sum_{i = l}^x a_i - \sum_{i = x + 1}^r a_i

这个容易理解,xx 移到 x+1x + 1 后,[l,x][l, x]xx 的距离都增加 11i=lraiix\sum_{i = l}^r a_i |i - x| 取值增加 i=lxai\sum_{i = l}^x a_i[x+1,r][x + 1, r]xx 的距离都减少 11i=lraiix\sum_{i = l}^r a_i |i - x| 取值减少 i=x+1rai\sum_{i = x + 1}^r a_i

然后,因为 aia_i 全是正整数,肯定会有一个 d[l,r]d \in [l, r],使得 x[l,d]x \in [l, d]i=lxaii=x+1rai\sum_{i = l}^x a_i - \sum_{i = x + 1}^r a_i 全是负数,x[d+1,r]x \in [d + 1, r]i=lxaii=x+1rai\sum_{i = l}^x a_i - \sum_{i = x + 1}^r a_i 全是非负数。

代价先下降后增长,这不是个单谷函数吗?

谷底的取值,就在 i=lraiix\sum_{i = l}^r a_i |i - x| 值第一次上升前的位置,也就是 i=lxaii=x+1rai\sum_{i = l}^x a_i - \sum_{i = x + 1}^r a_i 最接近 00 的位置,使 i=lxai\sum_{i = l}^x a_i 最接近 i=lrai2\dfrac{\sum_{i = l}^r a_i}{2} 的位置。

于是,我们可以在线段树上根据左右子树的和查找那个位置。

时间复杂度为 O((n+m)logn)O((n + m) \log n)

代码

#include <algorithm>
#include <cctype>
#include <cstdio>
#define inl inline
#define ll long long
#define reg register

using namespace std;

inl ll read(){
	reg char ch = getchar();reg ll res = 0, f = 1;
	while (ch > '9' || ch < '0'){if (ch == '-') f = -1;ch = getchar();}
	while (ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
	return res * f;
} // 快读
struct node{
	int siz;
	ll lsum, rsum, sum;
}tr[4000001];
ll tag[4000001];
inl node merge(node a, node b){
	if (!a.lsum && !a.rsum && !a.sum) return b;
	if (!b.lsum && !b.rsum && !b.sum) return a;
	reg node ans = {a.siz + b.siz, a.lsum + a.sum * b.siz + b.lsum, a.rsum + b.sum * a.siz + b.rsum, a.sum + b.sum};
	return ans;
} // 合并信息
inl void addtag(int u, ll x){
	tag[u] += x;
	tr[u].lsum += (1ll * tr[u].siz * tr[u].siz + tr[u].siz >> 1) * x;
	tr[u].rsum += (1ll * tr[u].siz * tr[u].siz + tr[u].siz >> 1) * x;
	tr[u].sum += 1ll * tr[u].siz * x;
} // 打标记
inl void downtag(int u){
	if (!tag[u]) return;
	addtag(u << 1, tag[u]);
	addtag(u << 1 | 1, tag[u]);
	tag[u] = 0;
} // 下放
inl void update(int u, int l, int r, int L, int R, int x){
	if (r < L || R < l) return;
	if (L <= l && r <= R){
		addtag(u, x);
		return;
	}
	downtag(u);
	reg int mid = l + r >> 1;
	update(u << 1, l, mid, L, R, x);
	update(u << 1 | 1, mid + 1, r, L, R, x);
	tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
} // 区间加
inl node query(int u, int l, int r, int L, int R){
	if (r < L || R < l) return {0, 0, 0, 0};
	if (L <= l && r <= R) return tr[u];
	downtag(u);
	reg int mid = l + r >> 1;
	return merge(query(u << 1, l, mid, L, R), query(u << 1 | 1, mid + 1, r, L, R));
} // 区间查询
int n, m, a[1000001], op, l, r, x;
inl ll calc(int l, int x, int r){
	if (x == l) return query(1, 1, n, l + 1, r).rsum;
	if (x == r) return query(1, 1, n, l, r - 1).lsum;
	return query(1, 1, n, l, x - 1).lsum + query(1, 1, n, x + 1, r).rsum;
} // 计算答案
inl ll solve(int L, int R){
	reg ll now = 0, suml = query(1, 1, n, 1, L - 1).sum, sumr = query(1, 1, n, 1, R).sum, sum;
	sum = sumr - suml;
	reg int u = 1, l = 1, r = n;
	while (r > l){
		downtag(u); // 注意,要先下放标记
		reg int mid = l + r >> 1;
		if (mid >= R){
			u = u << 1, r = mid;
			continue;
		} // 如果不能往右
		if (mid < L){
			now += tr[u << 1].sum, u = u << 1 | 1, l = mid + 1;
			continue;
		} // 如果不能往左
		if (now + tr[u << 1].sum - suml << 1 > sum) u = u << 1, r = mid;
		else now += tr[u << 1].sum, u = u << 1 | 1, l = mid + 1;
	}
	return calc(L, l, R);
} // 线段树上二分
inl void build(int u, int l, int r){
	if (l == r){
		tr[u] = {1, a[l], a[l], a[l]};
		return;
	}
	reg int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
} // 建树
ll lst;
int main(){
	n = read(), m = read();
	for (reg int i = 1;i <= n;i ++) a[i] = read();
	build(1, 1, n);
	while (m --){
		op = read(), l = read(), r = read();
		if (op == 1) x = read(), update(1, 1, n, l, r, lst % 1000001 + x);
		else lst = solve(l, r);
	}
	printf("%lld", lst); 
}