P5161 题解

First Post:

Last Update:

这题属于是对字符串的复习,也是对“撒关键点” 的复习。

提前计算长度为 的答案,即

把原数组变为差分数组,原问题变为间隔至少为 的子串对个数。

那么答案就是

这个间隔至少为 不好处理,我们不妨用 ,减去相邻或相交的子串对数。

这个 是好做的,用 SA 跑出 height 之后用单调栈解决即可。

考虑这个 "相邻或相交的子串对数" 实际上就是对于 ,求 的和。

不妨枚举 ,考察 的贡献。

可以仿照 [NOI2016] 优秀的拆分 的做法,在串上每隔 个位置撒一个关键点。

那么对于这个 ,它就至少会经过一个关键点。

不妨画图如下:

图中 是一个关键点,红色的是 所代表的串。我们计算 的答案。

容易发现,在 之后的红串,实际上就是

我们令

,也就是红串的前半段的长度。

那么一个 带来的贡献就是

考虑 的范围。

显然会有 (要不然前半段就不等了)

然后还会有 (否则红串长度就小于 了)

确定范围之后,上面要算的就是等差数列求和,直接计算即可。

都可以用 SA 求。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e5 + 5;
map<int,int> mp;
int n;
int a[N];
struct SA{
int sa[N],ton[N];
int t1[N],t2[N];
int s[N];
inline bool Equ(int *r,int i,int j,int l) { return r[i] == r[j] && r[min(n + 1,i + l)] == r[min(n + 1,j + l)];}
#define For(i,a,b) for(int i = (a);i <= (b);i++)
#define Rof(i,a,b) for(int i = (a);i >= (b);i--)
inline void get_sa(int n,int m)
{
int *x = t1,*y = t2,*t;
For(i,1,m) ton[i] = 0;
For(i,1,n) ton[x[i] = s[i]]++;
For(i,1,m) ton[i] += ton[i - 1];
Rof(i,n,1) sa[ton[x[i]]--] = i;
for(int j = 1,p;j <= n;j *= 2,m = p)
{
p = 0;
For(i,n - j + 1,n) y[++p] = i;
For(i,1,n) if(sa[i] > j) y[++p] = sa[i] - j;
For(i,1,m) ton[i] = 0;
For(i,1,n) ton[x[i]]++;
For(i,1,m) ton[i] += ton[i - 1];
Rof(i,n,1) sa[ton[x[y[i]]]--] = y[i],y[i] = 0;
t = x;x = y;y = t;p = 1;x[sa[1]] = 1;
For(i,2,n)
x[sa[i]] = Equ(y,sa[i - 1],sa[i],j) ? p : (++p);
if(p >= n) break;
}
}
int rk[N],height[N];
int ST[19][N],lg[N];
inline void Get_Hi()
{
For(i,1,n) rk[sa[i]] = i;
int j = 0;
For(i,1,n)
{
if(j) --j;
if(rk[i] != 1)
while(s[i + j] == s[sa[rk[i] - 1] + j]) ++j;
height[rk[i]] = j;
}
lg[0] = -1;
For(i,1,n) lg[i] = lg[i >> 1] + 1;
For(i,1,n) ST[0][i] = height[i];
For(j,1,18)
For(i,1,n - (1 << j) + 1)
ST[j][i] = min(ST[j - 1][i],ST[j - 1][i + (1 << j - 1)]);
}
inline int lcp(int x,int y)
{
x = rk[x];y = rk[y];
if(x > y) swap(x,y);++x;
int k = lg[y - x + 1];
return min(ST[k][x],ST[k][y-(1<<k)+1]);
}
#undef For
#undef Rof
}A,B;
inline long long Sum(int l,int r) { return 1ll * (l + r) * (r - l + 1) / 2;}
inline long long Calc(int s,int p,int len)
{
int l = len - min(len,s);
int r = min(len,p) - 1;
if(l <= r)
return Sum(l,r) + 1ll * (s - len + 1) * (r - l + 1);
else return 0;
}
long long f[N];
int stk[N],top;
int main()
{
cin >> n;
long long ans = 1ll * n * (n - 1) / 2;
for(int i = 1;i <= n;i++)
cin >> a[i];
for(int i = 1;i < n;i++) a[i] = a[i + 1] - a[i];
--n;
for(int i = 1,tot = 0;i <= n;i++)
if(mp.find(a[i]) == mp.end()) a[i] = mp[a[i]] = ++tot;
else a[i] = mp[a[i]];
for(int i = 1;i <= n;i++)
A.s[i] = a[i],B.s[i] = a[n - i + 1];
A.get_sa(n,n);
B.get_sa(n,n);
A.s[n + 1] = B.s[n + 1] = 0;
A.Get_Hi();
B.Get_Hi();

for(int i = 1;i <= n;i++)
{
while(top && A.height[stk[top]] >= A.height[i]) --top;
if(top) f[i] = f[stk[top]] + 1ll * (i - stk[top]) * A.height[i];
else f[i] = 1ll * i * A.height[i];
ans += f[i];
stk[++top] = i;
}
for(int len = 1;len <= n;len++)
for(int i = len;i + len <= n;i += len)
{
int j = i + len;
ans -= Calc(B.lcp(n - i + 1,n - j + 1),A.lcp(i,j),len);
}
cout << ans << endl;
return 0;
}