「清华集训 2017」某位歌姬的故事(动态规划)

题目大意

「清华集训 2017」某位歌姬的故事(UOJ 346)

求满足下列条件的,长度为 $n$ 的正整数序列 $a$ 数量 $\bmod 998244353$ 的结果:

  • $\forall a_i \le A$
  • $\forall i \in [1, Q], \max \{ a_{l_i}, a_{l_i + 1}, \cdots, a_{r_i} \} = m_i$

数据范围:$n, A \le 9 \times 10 ^ 8, Q \le 500$。

思路分析

先将序列离散化。对于离散化后的每一段,处理出这段可能达到的最大值。

考虑将限制按照最大值分组,最大值相同的限制一起处理,最后将每组的答案相乘得到总答案。其正确性是因为对于每段区间,它只会对包含它的限制的最小值(等于这段可能达到的最大值)贡献,所以贡献是不重不漏的。

于是问题就转化成了求满足下列条件的,长度为 $n’$ 的正整数序列 $a’$ 数量 $\bmod 998244353$ 的结果:

  • $\forall a’_i \le A’$
  • $\forall i \in [1, Q’], \max \{ a’_{l’_i}, a’_{l’_i + 1}, \cdots, a’_{r’_i} \} = m’$

可以使用 $\text{DP}$ 的方法来求解该问题。令 $\text{len}_i$ 表示第 $i$ 段的长度,预处理 $\text{mn}_i$ 表示右端点为第 $i$ 段区间的限制中左端点所在段的最小值。令 $\text{dp}_{i, j}$ 表示考虑到第 $i$ 位,最后一个 $A ^ {\prime}$ 在第 $j$ 段上的方案数。有两种转移:

  • $\text{dp}_{i, j} \leftarrow \text{dp}_{i - 1, j} \times (A ^ {\prime} - 1) ^ {\text{len}_i} \ (j \in [\text{mn}_i, i - 1])$
  • $\text{dp}_{i, i} \leftarrow \text{dp}_{i - 1, j} \times ((A ^ {\prime}) ^ {\text{len}_i} - (A ^ {\prime} - 1) ^ {\text{len}_i}) \ (j \in [0, i - 1])$

总时间复杂度 $O(T \times Q^2 \times \log n)$。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <cstdio>
#include <cstring>
#include <set>
#include <vector>
#include <algorithm>
using namespace std;

const int maxn = 500, maxm = 2 * maxn, mod = 998244353;
bool vis[maxn + 3], ok[maxm + 3];
int T, n, q, A, l[maxn + 3], r[maxn + 3], a[maxn + 3], m, pos[maxm + 3];
int M, L[maxm + 3], R[maxm + 3], mx[maxm + 3], mn[maxm + 3], Q, tm[maxn + 3];
int dp[maxm + 3][maxm + 3];
vector<int> V[maxm + 3];
multiset<int> S;

int Pow(int a, int b) {
int c = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod) {
if (b & 1) c = 1ll * a * c % mod;
}
return c;
}

int solve(int n, int w) {
// dp[i][j] 表示前 i 位的最后一个当前最大值在 j 的方案数目
dp[0][0] = 1;
for (int i = 1, k = pos[1]; i <= n; i++, k = pos[i]) {
for (int j = 0; j <= i; j++) dp[i][j] = 0;
int x = Pow(w - 1, R[k] - L[k] + 1), y = Pow(w, R[k] - L[k] + 1) - x;
y < 0 ? y += mod : 0;
for (int j = 0; j < i; j++) if (dp[i - 1][j]) {
if (j >= mn[i]) dp[i][j] = (dp[i][j] + 1ll * x * dp[i - 1][j]) % mod;
dp[i][i] = (dp[i][i] + 1ll * y * dp[i - 1][j]) % mod;
}
}
int res = 0;
for (int i = 0; i <= n; i++) {
res += dp[n][i], res < mod ? 0 : res -= mod;
}
return res;
}

int main() {
scanf("%d", &T);
while (T--) {
scanf("%d %d %d", &n, &q, &A);
m = 0, pos[++m] = 1, pos[++m] = n + 1;
for (int i = 1; i <= q; i++) {
scanf("%d %d %d", &l[i], &r[i], &a[i]);
pos[++m] = l[i], pos[++m] = r[i] + 1;
}
sort(pos + 1, pos + m + 1);
m = unique(pos + 1, pos + m + 1) - (pos + 1);
M = m - 1;
for (int i = 1; i <= M; i++) {
L[i] = pos[i], R[i] = pos[i + 1] - 1;
V[i].clear();
}
for (int i = 1; i <= q; i++) {
l[i] = lower_bound(pos + 1, pos + m + 1, l[i]) - pos;
r[i] = upper_bound(pos + 1, pos + m + 1, r[i]) - (pos + 1);
V[l[i]].push_back(i), V[r[i] + 1].push_back(i);
}
memset(vis, false, sizeof(vis));
S.clear();
for (int i = 1; i <= M; i++) {
for (int k: V[i]) {
if (!vis[k]) {
vis[k] = true;
S.insert(a[k]);
} else {
S.erase(S.lower_bound(a[k]));
}
}
mx[i] = S.empty() ? -1 : *S.begin();
}
Q = 0;
for (int i = 1; i <= q; i++) {
tm[++Q] = a[i];
}
sort(tm + 1, tm + Q + 1);
Q = unique(tm + 1, tm + Q + 1) - (tm + 1);
int ans = 1;
bool flag = true;
for (int i = 1; i <= Q; i++) {
m = 0;
for (int j = 1; j <= M; j++) {
if (tm[i] == mx[j]) pos[++m] = j;
}
for (int j = 1; j <= m; j++) mn[j] = -1;
for (int j = 1; j <= q; j++) {
if (tm[i] == a[j]) {
if (!m) { flag = false; break; }
l[j] = lower_bound(pos + 1, pos + m + 1, l[j]) - pos;
r[j] = upper_bound(pos + 1, pos + m + 1, r[j]) - (pos + 1);
mn[r[j]] = max(mn[r[j]], l[j]);
}
}
if (!flag) break;
ans = 1ll * ans * solve(m, tm[i]) % mod;
}
if (!flag) { puts("0"); continue; }
for (int i = 1; i <= M; i++) {
if (mx[i] == -1) ans = 1ll * ans * Pow(A, R[i] - L[i] + 1) % mod;
}
printf("%d\n", ans);
}
return 0;
}