思路
首先,看到区间修改区间查询应该想到线段树。
然后显然,我们可以用线段树快速求出 ∑i=lrai∣i−x∣ 的值。
每个节点可以维护 3 个信息:
lsum 的值为 ∑i=lr(r−i+1)×ai;
rsum 的值为 ∑i=lr(i−l+1)×ai;
sum 的值为 ∑i=lrai。
区间加时:
lsum=lsum+2siz(siz+1)x
rsum=rsum+2siz(siz+1)x
sum=sum+siz×x。
合并时:
lsum=lsonlsum+lsonsum×rsonsiz+rsonlsum;
rsum=lsonrsum+rsonsum×lsonsiz+rsonrsum;
sum=lsonsum+rsonsum。
为什么这样合并?以 lsum 举例:
lsonlsum+lsonsum×rsonsiz+rsonlsum
=i=lsonl∑lsonr(lsonr−i+1)×ai+i=lsonl∑lsonrrsonsiz×ai+i=rsonl∑rsonr(rsonr−i+1)×ai
根据乘法分配律,得:
i=lsonl∑lsonr(lsonr−i+1+rsonsiz)×ai+i=rsonl∑rsonr(rsonr−i+1)×ai
因为线段树上两个子节点存储的区间是相邻的,即 rsonl=lsonr+1。
显然,l=lsonl,r=rsonr。得:
i=l∑lsonr(lsonr−i+1+(r−(lsonr+1)+1))×ai+i=lsonr+1∑r(r−i+1)×ai
i=l∑lsonr(lsonr−i+1+(r−lsonr−1+1))×ai+i=lsonr+1∑r(r−i+1)×ai
i=l∑lsonr(lsonr−i+1+r−lsonr)×ai+i=lsonr+1∑r(r−i+1)×ai
i=l∑lsonr(lsonr−lsonr−i+1+r)×ai+i=lsonr+1∑r(r−i+1)×ai
i=l∑lsonr(r−i+1)×ai+i=lsonr+1∑r(r−i+1)×ai
这两个 ∑ 可以合并在一起,变成:
i=l∑r(r−i+1)×ai
rsum 也是类似的。
这样,我们就可以成功地查询任意一个区间的 lsum 和 rsum 了!
查询 ∑i=lrai∣i−x∣,就相当于查询 [l,x) 的 lsum+(x,r] 的 rsum。
然后,我们就可以在 [l,r] 里暴力枚举 x,时间复杂度 O(nmlogn)。
我们考虑继续优化刚才的算法。
首先,对于 l≤x<r,将 x 移到 x+1,∑i=lrai∣i−x∣ 会增加 ∑i=lxai−∑i=x+1rai。
这个容易理解,x 移到 x+1 后,[l,x] 离 x 的距离都增加 1,∑i=lrai∣i−x∣ 取值增加 ∑i=lxai;[x+1,r] 离 x 的距离都减少 1,∑i=lrai∣i−x∣ 取值减少 ∑i=x+1rai。
然后,因为 ai 全是正整数,肯定会有一个 d∈[l,r],使得 x∈[l,d] 时 ∑i=lxai−∑i=x+1rai 全是负数,x∈[d+1,r] 时 ∑i=lxai−∑i=x+1rai 全是非负数。
代价先下降后增长,这不是个单谷函数吗?
谷底的取值,就在 ∑i=lrai∣i−x∣ 值第一次上升前的位置,也就是 ∑i=lxai−∑i=x+1rai 最接近 0 的位置,使 ∑i=lxai 最接近 2∑i=lrai 的位置。
于是,我们可以在线段树上根据左右子树的和查找那个位置。
时间复杂度为 O((n+m)logn)。
代码
#include <algorithm>
#include <cctype>
#include <cstdio>
#define inl inline
#define ll long long
#define reg register
using namespace std;
inl ll read(){
reg char ch = getchar();reg ll res = 0, f = 1;
while (ch > '9' || ch < '0'){if (ch == '-') f = -1;ch = getchar();}
while (ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
return res * f;
} // 快读
struct node{
int siz;
ll lsum, rsum, sum;
}tr[4000001];
ll tag[4000001];
inl node merge(node a, node b){
if (!a.lsum && !a.rsum && !a.sum) return b;
if (!b.lsum && !b.rsum && !b.sum) return a;
reg node ans = {a.siz + b.siz, a.lsum + a.sum * b.siz + b.lsum, a.rsum + b.sum * a.siz + b.rsum, a.sum + b.sum};
return ans;
} // 合并信息
inl void addtag(int u, ll x){
tag[u] += x;
tr[u].lsum += (1ll * tr[u].siz * tr[u].siz + tr[u].siz >> 1) * x;
tr[u].rsum += (1ll * tr[u].siz * tr[u].siz + tr[u].siz >> 1) * x;
tr[u].sum += 1ll * tr[u].siz * x;
} // 打标记
inl void downtag(int u){
if (!tag[u]) return;
addtag(u << 1, tag[u]);
addtag(u << 1 | 1, tag[u]);
tag[u] = 0;
} // 下放
inl void update(int u, int l, int r, int L, int R, int x){
if (r < L || R < l) return;
if (L <= l && r <= R){
addtag(u, x);
return;
}
downtag(u);
reg int mid = l + r >> 1;
update(u << 1, l, mid, L, R, x);
update(u << 1 | 1, mid + 1, r, L, R, x);
tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
} // 区间加
inl node query(int u, int l, int r, int L, int R){
if (r < L || R < l) return {0, 0, 0, 0};
if (L <= l && r <= R) return tr[u];
downtag(u);
reg int mid = l + r >> 1;
return merge(query(u << 1, l, mid, L, R), query(u << 1 | 1, mid + 1, r, L, R));
} // 区间查询
int n, m, a[1000001], op, l, r, x;
inl ll calc(int l, int x, int r){
if (x == l) return query(1, 1, n, l + 1, r).rsum;
if (x == r) return query(1, 1, n, l, r - 1).lsum;
return query(1, 1, n, l, x - 1).lsum + query(1, 1, n, x + 1, r).rsum;
} // 计算答案
inl ll solve(int L, int R){
reg ll now = 0, suml = query(1, 1, n, 1, L - 1).sum, sumr = query(1, 1, n, 1, R).sum, sum;
sum = sumr - suml;
reg int u = 1, l = 1, r = n;
while (r > l){
downtag(u); // 注意,要先下放标记
reg int mid = l + r >> 1;
if (mid >= R){
u = u << 1, r = mid;
continue;
} // 如果不能往右
if (mid < L){
now += tr[u << 1].sum, u = u << 1 | 1, l = mid + 1;
continue;
} // 如果不能往左
if (now + tr[u << 1].sum - suml << 1 > sum) u = u << 1, r = mid;
else now += tr[u << 1].sum, u = u << 1 | 1, l = mid + 1;
}
return calc(L, l, R);
} // 线段树上二分
inl void build(int u, int l, int r){
if (l == r){
tr[u] = {1, a[l], a[l], a[l]};
return;
}
reg int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
} // 建树
ll lst;
int main(){
n = read(), m = read();
for (reg int i = 1;i <= n;i ++) a[i] = read();
build(1, 1, n);
while (m --){
op = read(), l = read(), r = read();
if (op == 1) x = read(), update(1, 1, n, l, r, lst % 1000001 + x);
else lst = solve(l, r);
}
printf("%lld", lst);
}