HNOI 2019 序列 题解

First Post:

Last Update:

保序回归的模型容易发现,但我对单调性这类东西的直觉可能还不够敏锐。

首先考虑没有修改怎么做。根据保序回归的结论,当 时,必有 。而且如果一段连续的 相等,那么这个 取这段中 的平均数是最优的。所以可以考虑用一个单调栈来维护所有 相等的段。每次在后面新加入一个元素 时,如果栈顶的均值比 大,那么就弹出栈顶,将栈顶那一段与 合并,直到栈顶的均值小于 所在那一段的均值,然后将 这一段入栈。统计答案是容易的。这样就可以在 的时间内完成对单个序列的求解。

在有修改的时候,因为每次修改独立,且只修改一个位置,容易想到预处理序列前缀后缀的信息,回答询问时再拼起来。具体地,对于修改 ,最优方案肯定是保留 相等段划分的一个前缀,保留 相等段划分的一个后缀,剩下的部分与 同处一段。 设 表示在 中从左到右第 段对应的平均值, 表示在 中从左到右第 段对应的平均值,设 表示 的第 段与 的第 段中间夹着的所有数的平均值。那么一对 合法当且仅当:

  1. 段与 段中的数确实可以合并成一段(即如果对这些数跑上文的保序回归算法,最后的栈中一定只会剩下一段)

其实这样的 是唯一的,因为这个保序回归问题的解是唯一的,因为我们如果把 看成两个 维向量,那代价函数就是它们在 维空间中的欧氏距离,而不等式组 的解空间是个凸集,所以最优的 是唯一的。

那么如何找到这样一对 ?可以先枚举 ,找到最大的 满足 ,设为 。这一步可用二分实现,因为如果 满足该条件, 也一定满足。另一方面,对于 ,数对 虽然肯定满足 ,但是它们肯定不满足条件 2。所以对于一个 ,我们只需关注 是否合法,即 是否成立。

对于 合法的 ,我们要选取哪个 呢?答案是选取最小的 ,原因也是基本类似的,更小的 会让条件 1 不成立,更大的 会并上一些不该合并的段,让条件 2 不成立。

另一方面,找这个 也是可以二分的:

首先有观察:这是因为,对一个 求出的 ,就是使 取到最大值的

所以,如果设 求出的 ,设 求出的 ,那么 ,因为 都小于 ,所以 。所以, 满足条件可以推出 满足条件。

于是,我们可以先在 的单调栈上二分 ,检查 的合法性时在 的单调栈上二分 ,就可以找到这一对 。这需要我们使用数据结构维护每个前缀和每个后缀的单调栈。有很多人写了主席树,其实这没有必要。我们把每个元素向它加入栈中时的上一个元素连边,就会形成一棵树,某个前缀的单调栈就是树上从这个点开始到根的路径,用树上倍增代替二分即可。代码很好写。

于是本题就做完了,时间复杂度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5,P = 998244353,Lg = 18;
typedef long long ll;
typedef double db;
int n,m,inv[N];
int a[N];
ll sum[N];
int s2[N];
struct Node {
ll s;int c,l,r;
Node(){}
Node(const ll _s,const int _c,const int _l,const int _r):s(_s),c(_c),l(_l),r(_r){}
Node operator + (const Node &rhs) const { return Node(s + rhs.s,c + rhs.c,min(l,rhs.l),max(r,rhs.r));}
bool operator < (const Node &rhs) const { return (__int128) s * rhs.c < (__int128) rhs.s * c;}
};
inline int Getval(int l,int r) {
int res = (s2[r] - s2[l - 1] + P) % P,sm = (sum[r] - sum[l - 1]) % P;
res = (res + P - 1ll * sm * sm % P * inv[r - l + 1] % P) % P;
return res;
}
inline int Getvc(int l,int r,int x,int y) {
int res = (s2[r] - s2[l - 1] + P) % P,sm = (sum[r] - sum[l - 1]) % P;
res = (res - 1ll * x * x % P + 1ll * y * y % P) % P;
sm = (sm - x + y) % P;
if(res < 0) res += P;if(sm < 0) sm += P;
res = (res + P - 1ll * sm * sm % P * inv[r - l + 1] % P) % P;
return res;
}
inline Node Getave(int l,int r) { if(l > r) return Node(0,0,0,0);return Node(sum[r] - sum[l - 1],r - l + 1,l,r);}
int pre[N],suf[N]; // 用树维护单调栈
int ac1[Lg][N],ac2[Lg][N];
Node t1[N],t2[N];
Node sum1[Lg][N];
int dep1[N],dep2[N];
int fv[N],gv[N];
inline void Build() {
t1[0] = Node(0,1,0,0);t2[n + 1] = Node(1e9,1,n+1,0);
for(int i = 1;i <= n;i++) {
Node now(a[i],1,i,i);
int p = i - 1;
while(p && now < t1[p]) now = now + t1[p],p = pre[p];
pre[i] = p;t1[i] = now;
dep1[i] = dep1[pre[i]] + 1;
fv[i] = (fv[pre[i]] + Getval(now.l,now.r)) % P;
}
for(int i = n;i >= 1;i--) {
Node now(a[i],1,i,i);
int p = i + 1;
while(p != n + 1 && t2[p] < now) now = now + t2[p],p = suf[p];
suf[i] = p;t2[i] = now;
dep2[i] = dep2[suf[i]] + 1;
gv[i] = (gv[suf[i]] + Getval(now.l,now.r)) % P;
}
for(int i = 1;i <= n;i++) {
ac1[0][i] = pre[i],ac2[0][i] = suf[i];
sum1[0][i] = t1[i];
}
for(int j = 1;j < Lg;j++)
for(int i = 1;i <= n;i++) {
ac1[j][i] = ac1[j - 1][ac1[j - 1][i]];
ac2[j][i] = ac2[j - 1][ac2[j - 1][i]];
sum1[j][i] = sum1[j - 1][i] + sum1[j - 1][ac1[j - 1][i]];
}
}
inline int jump(int op,int x,int k) {
for(int i = Lg - 1;i >= 0;i--)
if((k >> i) & 1) x = (op == 1) ? ac1[i][x] : ac2[i][x];
return x;
}
inline int GetL(int x,int y,int R) { // a[x] = y
if(x == 1) return 1;
Node rig = (R ? Getave(x + 1,t2[R].r) : Node(0,0,1e9,0));
if(t1[x - 1] < Node(y,1,x,x) + rig) return x;
if(t1[pre[x - 1]] < t1[x - 1] + Node(y,1,x,x) + rig) return x - 1;
int now = x - 1;
Node sm(0,0,0,0);
for(int i = Lg - 1;i >= 0;i--) {
Node val = sm + sum1[i][now];
int nx = ac1[i][now];
val = val + Node(y,1,x,x) + rig;
if(!(t1[nx] < val)) sm = sm + sum1[i][now],now = nx;
}
return now;
}
inline int Solve(int x,int y) {
int kl = 0,kr = 0;
int lef = 0,rig = dep2[x + 1] - 1;
while(lef < rig) {
int mid = lef + rig >> 1;
int now = jump(2,x + 1,mid);
int l0 = GetL(x,y,now),ll = l0 < x ? t1[l0].l : l0;
Node val = Getave(ll,x - 1) + Node(y,1,x,x) + Getave(x + 1,t2[now].r);
if(val < t2[suf[now]]) rig = mid;
else lef = mid + 1;
}
int tt = GetL(x,y,0),ttl = tt < x ? t1[tt].l : tt;
if(Getave(ttl,x - 1) + Node(y,1,x,x) < t2[x + 1]) kr = x;
else kr = t2[jump(2,x + 1,lef)].r;
kl = (kr == x) ? tt : GetL(x,y,jump(2,x + 1,lef));
if(kl < x) kl = t1[kl].l;
return (1ll * fv[kl - 1] + gv[kr + 1] + Getvc(kl,kr,a[x],y)) % P;
}
int main() {
cin >> n >> m;
for(int i = 1;i <= n;i++) cin >> a[i];
for(int i = 1;i <= n;i++)
sum[i] = sum[i - 1] + a[i],s2[i] = (s2[i - 1] + 1ll * a[i] * a[i] % P) % P;
inv[1] = 1;
for(int i = 2;i <= n;i++) inv[i] = 1ll * inv[P % i] * (P - P / i) % P;
Build();
cout << fv[n] << endl;
for(int i = 1;i <= m;i++) {
int x,y;
cin >> x >> y;
cout << Solve(x,y) << endl;
}
return 0;
}