远古 AGC 还是比现代 AGC (2023) 好做太多了。我觉得在模拟赛 eat shit
之余,还是需要这种思维体操活跃脑子。
首先,题意相当于让我们计数有多少个长度为 的 RB 串,使得其能被划分为 个子序列,每个子序列中要么 R 比 B
多,要么 R 与 B 相等且最后一个元素是 B。
考虑如何找到一种最优的划分策略,使得我们能把所有有解的串都判定出来。
容易知道,整个串中的 R 不会少于 B。我们枚举 R 比 B 多几个,设为 。
设 类子序列为其中 R 的个数多于
B 的序列, 类子序列为 R 的个数与 B
相等的序列。
最优的决策肯定是让
个子序列满足 R 比 B 恰好多 ,我们还要找到 个 类子序列。实际上,让每个 类子序列的长度都 一定是最优的,因为像 RBRB
或 RRBB
这样的串都能被拆分为两个
RB
,从而满足更多子序列的要求。但这么划分之后,整串中可能还有一些元素没有被划分进某个子序列,那么我们直接把这些元素丢进任意一个
类子序列即可,因为这样并不会改变这个子序列 R 比 B
多的现状,也不会不满足其它的要求。事实上 类子序列就是一个垃圾桶,其可以在划分完
类子序列后,接收剩下的所有元素。
读者可能会发现上文分析的问题:如果没有 类子序列怎么办?我们将 和 分开考虑。先不妨设 。
此时我们只要在原串中找到
个 RB
子序列。考虑如何求出原串最多能找到多少个。将 R 看作
,B 看作 ,从左往右推,维护一个变量 。设当前元素为 ,令 。此后,如果 且该位置为 B,那么其肯定能在前面找到一个 R 与其匹配。如果
,令 ,然后处理下一个字符。容易证明这个算法求出来的 B
类子序列个数是最多的。
但只有这个算法还是没法计数。考虑把原串看作二维平面上从 出发的一条折线,一个 R 和一个 B
分别代表向右上方走一步和向右下方走一步。每个位置的 其实就是折线上每个整点的
坐标,只不过,我们强制这条折线在向下越过直线
时会被扳回到这条直线上。上述算法的结果其实就是 。
如何算将折线扳回
轴的次数?我们不妨让折线在穿过
轴之后还能向下走,此时折线最后会停在位置 。观察原来每一次扳回 轴时的位置,发现其对应的就是折线上
坐标的前缀最小值位置!这个事实的正确性容易理解。进一步地,由于
坐标每次只会变化 或 ,所以相邻两个前缀最小值的差也是 。所以,前缀最小值的个数,就等于最后的前缀最小值,也就是整条折线上
坐标最小值的相反数!
我们设 为将折线扳回 轴的次数。因为 ,所以
,设 ,那么
,也就是说转化后的折线不会向下穿过直线 。现在问题已经转化为,计数从 走到 的折线数量,使得折线不穿过直线
。
首先,从 走到 的折线数量为 ,因为总共会走 个右上步。至于如何处理
的限制,将其转化为
“折线不经过直线 “
后,就是经典问题了。我们首先算出总方案数 ,然后,对于每条经过直线
的折线,将其与该直线第一个交点以后的所有部分,沿该直线进行对称。此时这条折线被对应到一条从
出发,在
结束的,没有其它限制的折线!容易证明这个映射构成双射。所以不合法方案数也可以算出来,是
。
综上所述,在 时,答案为
。
接下来考虑
的情况。此时原串要被划分为 个 B
类子序列,所以原串的最后一个字符肯定为 B。而对于前 个字符,问题被转化成了一个 的问题。在把前 个字符划分为 个 类子序列和 个 类子序列后,再把最后一个 B 填到那个
类子序列的末尾即可。也就是说,这个情况可以直接套用 时的结论,令 即可。
预处理组合数即可 算出每个
的答案,对所有 求和即可。时空复杂度 ,这题终于做完了。
代码很好写。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| #include <bits/stdc++.h> using namespace std; const int N = 5e5 + 5,P = 998244353; inline void Plus(int &x,const int &y) { x += y;if(x >= P) x -= P;} int fac[N],inv[N],ifac[N]; inline void init(int n) { fac[0] = 1; for(int i = 1;i <= n;i++) fac[i] = 1ll * fac[i - 1] * i % P; inv[1] = 1; for(int i = 2;i <= n;i++) inv[i] = 1ll * inv[P % i] * (P - P / i) % P; ifac[0] = 1; for(int i = 1;i <= n;i++) ifac[i] = 1ll * ifac[i - 1] * inv[i] % P; } inline int C(int n,int m) { if(n < 0 || m < 0 || n < m) return 0;return 1ll * fac[n] * ifac[m] % P * ifac[n - m] % P;} int n,K; int main() { cin >> n >> K; init(K + 1); int ans = 0; for(int d = 2 - (K & 1);d <= K;d += 2) { if(max(0,n - d) > (K - d) / 2) continue; int t = (K - d) / 2 - max(0,n - d); int tmp = C(K,(K + d) / 2); Plus(tmp,P - C(K,(K - 2 - t - t - d) / 2)); Plus(ans,tmp); } if(!(K & 1) && K / 2 >= n) { int d = 1;--K; int t = (K - d) / 2 - max(0,n - d); int tmp = C(K,(K + d) / 2); Plus(tmp,P - C(K,(K - 2 - t - t - d) / 2)); Plus(ans,tmp); } cout << ans << endl; return 0; }
|