「学习笔记」同余最短路

$\Large\mathcal{Problem}$

$\qquad$ 给定一个 $a$ 数组,问在 $[l, r]$ 范围内,有多少个 $k$ 满足 $\sum\limits_{i=1}^n a_i \times x_i = k$ 存在非负整数解。

$\Large\mathcal{Solution}$

$\qquad$ 在 $l, r$ 比较小的情况下,我们可以用完全背包来实现。

$\qquad$ 但在 $l, r$ 超出 $int$ 范围时,即使用 $bitset$ 优化后的完全背包也难以通过。

$\qquad$ 这时候就要用上我们的主角 $—$ 同余最短路

$\Large\texttt{实现}$

$\qquad$ 我们令 $Base$ 为某一个 $a_i$。

$\qquad$ $f_i$ 表示满足条件且 $mod$ $Base$ $=$ $i$ 的最小的数

$\qquad$ 那么 $k$ 满足条件当且仅当 $k \ge f_{k \% Base}$

$\qquad$ 建边的话就对于 $i \in [0, Base - 1]$ 向 $(i + a_j)$ $\%$ $Base$ 连一条长度为 $a_j$ 的边。

$\qquad$ (可以发现,边数的多少是依托于点数的多少的,因此 $Base$ 选用 $\min\limits_{i = 1}^n a_i$ 时最优)

$\qquad$ 然后就直接跑最短路即可。

$\qquad$ 据说,由于这种建边方式,$SPFA$ 不会被卡

$\Large\texttt{模板}$

$\qquad$ $luogu$ $P3403$ 跳楼机

$\qquad$ 给定 $a, b, c, h$, 问在 $[0, h - 1]$ 范围内,有多少个 $k$ 满足 $ax + by + cz = k$ 存在非负整数解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>

template <typename Tp>
inline void read(Tp &x) {
x = 0;
bool f = true; char ch = getchar();
for ( ; ch < '0' || ch > '9'; ch = getchar()) f ^= ch == '-';
for ( ; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + (ch ^ 48);
x = f ? x : -x;
}

const int N = 1e5 + 7;

int Head[N], Cnt = 0;

struct Edge {
int to, nxt, val;
};
Edge e[N << 1];

inline void Add_edge(int x, int y, int k) {
e[++Cnt] = (Edge) {y, Head[x], k}, Head[x] = Cnt;
}

struct Node {
int id;
long long val;
};
std::priority_queue <Node> Q;
inline bool operator < (Node x, Node y) {
return x.val > y.val;
}

long long Dis[N];
bool Vis[N];

inline void Dij() {
memset(Dis, 0x3f, sizeof(Dis));
Dis[1] = 1;
Q.push( (Node) {1, 1});
while (Q.size()) {
int u = Q.top().id; long long Val = Q.top().val; Q.pop();
if (Vis[u]) continue;
Vis[u] = true;
for (int i = Head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (Dis[v] > Val + e[i].val) {
Dis[v] = Val + e[i].val;
Q.push( (Node) {v, Dis[v]});
}
}
}
}

int main() {
long long n; int x, y, z;
read(n), read(x), read(y), read(z);
if (x > y) std::swap(x, y);
if (x > z) std::swap(x, z);
if (x == 1) return printf("%lld\n", n), 0;
for (int i = 0; i < x; ++i) {
Add_edge(i, (i + y) % x, y);
Add_edge(i, (i + z) % x, z);
}
Dij();
long long Ans = 0;
for (int i = 0; i < x; ++i) if (Dis[i] <= n) Ans += std::max((n - Dis[i]) / x + 1, 0LL);
printf("%lld\n", Ans);
return 0;
}

$\Large\texttt{例题}$