cranelift_codegen/unionfind.rs
1//! Simple union-find data structure.
2
3use crate::trace;
4use cranelift_entity::{packed_option::ReservedValue, EntityRef, SecondaryMap};
5use std::hash::Hash;
6use std::mem::swap;
7
8/// A union-find data structure. The data structure can allocate
9/// `Idx`s, indicating eclasses, and can merge eclasses together.
10///
11/// Running `union(a, b)` will change the canonical `Idx` of `a` or `b`.
12/// Usually, this is chosen based on what will minimize path lengths,
13/// but it is also possible to _pin_ an eclass, such that its canonical `Idx`
14/// won't change unless it gets unioned with another pinned eclass.
15///
16/// In the context of the egraph pass, merging two pinned eclasses
17/// is very unlikely to happen – we do not know a single concrete test case
18/// where it does. The only situation where it might happen looks as follows:
19///
20/// 1. We encounter terms `A` and `B`, and the optimizer does not find any
21/// reason to union them together.
22/// 2. We encounter a term `C`, and we rewrite `C -> A`, and separately, `C -> B`.
23///
24/// Unless `C` somehow includes some crucial hint without which it is hard to
25/// notice that `A = B`, there's probably a rewrite rule that we should add.
26///
27/// Worst case, if we do merge two pinned eclasses, some nodes will essentially
28/// disappear from the GVN map, which only affects the quality of the generated
29/// code.
30#[derive(Clone, Debug, PartialEq)]
31pub struct UnionFind<Idx: EntityRef> {
32 parent: SecondaryMap<Idx, Val<Idx>>,
33 /// The `rank` table is used to perform the union operations optimally,
34 /// without creating unnecessarily long paths. Pins are represented by
35 /// eclasses with a rank of `u8::MAX`.
36 ///
37 /// `rank[x]` is the upper bound on the height of the subtree rooted at `x`.
38 /// The subtree is guaranteed to have at least `2**rank[x]` elements,
39 /// unless `rank` has been artificially inflated by pinning.
40 rank: SecondaryMap<Idx, u8>,
41
42 pub(crate) pinned_union_count: u64,
43}
44
/// Newtype around an `Idx`, used as the element type of the `parent` table.
///
/// Wrapping the index lets the table's default entry be the reserved value
/// (see the `Default` impl) rather than whatever `Idx` itself would default to.
#[derive(Clone, Debug, PartialEq)]
struct Val<Idx>(Idx);
47
48impl<Idx: EntityRef + ReservedValue> Default for Val<Idx> {
49 fn default() -> Self {
50 Self(Idx::reserved_value())
51 }
52}
53
54impl<Idx: EntityRef + Hash + std::fmt::Display + Ord + ReservedValue> UnionFind<Idx> {
55 /// Create a new `UnionFind` with the given capacity.
56 pub fn with_capacity(cap: usize) -> Self {
57 UnionFind {
58 parent: SecondaryMap::with_capacity(cap),
59 rank: SecondaryMap::with_capacity(cap),
60 pinned_union_count: 0,
61 }
62 }
63
64 /// Add an `Idx` to the `UnionFind`, with its own equivalence class
65 /// initially. All `Idx`s must be added before being queried or
66 /// unioned.
67 pub fn add(&mut self, id: Idx) {
68 debug_assert!(id != Idx::reserved_value());
69 self.parent[id] = Val(id);
70 }
71
72 /// Find the canonical `Idx` of a given `Idx`.
73 pub fn find(&self, mut node: Idx) -> Idx {
74 while node != self.parent[node].0 {
75 node = self.parent[node].0;
76 }
77 node
78 }
79
80 /// Find the canonical `Idx` of a given `Idx`, updating the data
81 /// structure in the process so that future queries for this `Idx`
82 /// (and others in its chain up to the root of the equivalence
83 /// class) will be faster.
84 pub fn find_and_update(&mut self, mut node: Idx) -> Idx {
85 // "Path halving" mutating find (Tarjan and Van Leeuwen).
86 debug_assert!(node != Idx::reserved_value());
87 while node != self.parent[node].0 {
88 let next = self.parent[self.parent[node].0].0;
89 debug_assert!(next != Idx::reserved_value());
90 self.parent[node] = Val(next);
91 node = next;
92 }
93 debug_assert!(node != Idx::reserved_value());
94 node
95 }
96
97 /// Request a stable identifier for `node`.
98 ///
99 /// After an `union` operation, the canonical representative of one
100 /// of the eclasses being merged together necessarily changes. If a pinned
101 /// eclass is merged with a non-pinned eclass, it'll be the other eclass
102 /// whose representative will change.
103 ///
104 /// If two pinned eclasses are unioned, one of the pins gets broken,
105 /// which is reported in the statistics for the pass. No concrete test case
106 /// which triggers this is known.
107 pub fn pin_index(&mut self, mut node: Idx) -> Idx {
108 node = self.find_and_update(node);
109 self.rank[node] = u8::MAX;
110 node
111 }
112
113 /// Merge the equivalence classes of the two `Idx`s.
114 pub fn union(&mut self, a: Idx, b: Idx) {
115 let mut a = self.find_and_update(a);
116 let mut b = self.find_and_update(b);
117
118 if a == b {
119 return;
120 }
121
122 if self.rank[a] < self.rank[b] {
123 swap(&mut a, &mut b);
124 } else if self.rank[a] == self.rank[b] {
125 self.rank[a] = self.rank[a].checked_add(1).unwrap_or_else(
126 #[cold]
127 || {
128 // Both `a` and `b` are pinned.
129 //
130 // This should only occur if we rewrite X -> Y and X -> Z,
131 // yet neither Y -> Z nor Z -> Y can be established without
132 // the "hint" provided by X. This probably means we're
133 // missing an optimization rule.
134 self.pinned_union_count += 1;
135 u8::MAX
136 },
137 );
138 }
139
140 self.parent[b] = Val(a);
141 trace!("union: {}, {}", a, b);
142 }
143}