cranelift_codegen/egraph/
cost.rs

1//! Cost functions for egraph representation.
2
3use crate::ir::Opcode;
4
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary unit. The ordering is meant to be
/// meaningful, but the value of a single unit is arbitrary (and "not to
/// scale"). We use a collection of heuristics to try to make this
/// approximation at least usable.
///
/// A `Cost` packs two components into a single `u32`:
///
/// * the *opcode cost*, stored in the high bits. We start by defining a
///   cost for each opcode (see `Cost::of_opcode` below). The cost of
///   computing some value is the cost of its opcode plus the cost of
///   computing its inputs.
/// * the *depth*, stored in the low `Cost::DEPTH_BITS` bits. A pure
///   operation is one level deeper than its deepest operand.
///
/// Because the opcode cost occupies the more significant bits, comparing
/// the packed integers orders costs by opcode cost first and breaks ties
/// by depth.
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping. Adding two costs
/// adds their opcode costs and keeps the larger of the two depths.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". Any opcode cost that reaches
/// `Cost::MAX_OP_COST` saturates *to* infinity (see `Cost::new`), and
/// adding anything to an infinite cost yields an infinite cost again.
/// An infinite cost is used to represent a value that cannot be
/// computed, or otherwise serve as a sentinel when performing search
/// for the lowest-cost representation of a value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
35
36impl core::fmt::Debug for Cost {
37    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38        if *self == Cost::infinity() {
39            write!(f, "Cost::Infinite")
40        } else {
41            f.debug_struct("Cost::Finite")
42                .field("op_cost", &self.op_cost())
43                .field("depth", &self.depth())
44                .finish()
45        }
46    }
47}
48
49impl Ord for Cost {
50    #[inline]
51    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
52        // We make sure that the high bits are the op cost and the low bits are
53        // the depth. This means that we can use normal integer comparison to
54        // order by op cost and then depth.
55        //
56        // We want to break op cost ties with depth (rather than the other way
57        // around). When the op cost is the same, we prefer shallow and wide
58        // expressions to narrow and deep expressions and breaking ties with
59        // `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
60        // to `((a + b) + c) + d`. This is beneficial because it exposes more
61        // instruction-level parallelism and shortens live ranges.
62        self.0.cmp(&other.0)
63    }
64}
65
66impl PartialOrd for Cost {
67    #[inline]
68    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
69        Some(self.cmp(other))
70    }
71}
72
73impl Cost {
74    const DEPTH_BITS: u8 = 8;
75    const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
76    const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
77    const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;
78
79    pub(crate) fn infinity() -> Cost {
80        // 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
81        // only for heuristics and always saturate so this suffices!)
82        Cost(u32::MAX)
83    }
84
85    pub(crate) fn zero() -> Cost {
86        Cost(0)
87    }
88
89    /// Construct a new `Cost` from the given parts.
90    ///
91    /// If the opcode cost is greater than or equal to the maximum representable
92    /// opcode cost, then the resulting `Cost` saturates to infinity.
93    fn new(opcode_cost: u32, depth: u8) -> Cost {
94        if opcode_cost >= Self::MAX_OP_COST {
95            Self::infinity()
96        } else {
97            Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
98        }
99    }
100
101    fn depth(&self) -> u8 {
102        let depth = self.0 & Self::DEPTH_MASK;
103        u8::try_from(depth).unwrap()
104    }
105
106    fn op_cost(&self) -> u32 {
107        (self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
108    }
109
110    /// Return the cost of an opcode.
111    fn of_opcode(op: Opcode) -> Cost {
112        match op {
113            // Constants.
114            Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),
115
116            // Extends/reduces.
117            Opcode::Uextend
118            | Opcode::Sextend
119            | Opcode::Ireduce
120            | Opcode::Iconcat
121            | Opcode::Isplit => Cost::new(1, 0),
122
123            // "Simple" arithmetic.
124            Opcode::Iadd
125            | Opcode::Isub
126            | Opcode::Band
127            | Opcode::Bor
128            | Opcode::Bxor
129            | Opcode::Bnot
130            | Opcode::Ishl
131            | Opcode::Ushr
132            | Opcode::Sshr => Cost::new(3, 0),
133
134            // "Expensive" arithmetic.
135            Opcode::Imul => Cost::new(10, 0),
136
137            // Everything else.
138            _ => {
139                // By default, be slightly more expensive than "simple"
140                // arithmetic.
141                let mut c = Cost::new(4, 0);
142
143                // And then get more expensive as the opcode does more side
144                // effects.
145                if op.can_trap() || op.other_side_effects() {
146                    c = c + Cost::new(10, 0);
147                }
148                if op.can_load() {
149                    c = c + Cost::new(20, 0);
150                }
151                if op.can_store() {
152                    c = c + Cost::new(50, 0);
153                }
154
155                c
156            }
157        }
158    }
159
160    /// Compute the cost of the operation and its given operands.
161    ///
162    /// Caller is responsible for checking that the opcode came from an instruction
163    /// that satisfies `inst_predicates::is_pure_for_egraph()`.
164    pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
165        let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
166        Cost::new(c.op_cost(), c.depth().saturating_add(1))
167    }
168
169    /// Compute the cost of an operation in the side-effectful skeleton.
170    pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {
171        Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap(), (arity != 0) as _)
172    }
173}
174
175impl std::iter::Sum<Cost> for Cost {
176    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
177        iter.fold(Self::zero(), |a, b| a + b)
178    }
179}
180
181impl std::default::Default for Cost {
182    fn default() -> Cost {
183        Cost::zero()
184    }
185}
186
187impl std::ops::Add<Cost> for Cost {
188    type Output = Cost;
189
190    fn add(self, other: Cost) -> Cost {
191        let op_cost = self.op_cost().saturating_add(other.op_cost());
192        let depth = std::cmp::max(self.depth(), other.depth());
193        Cost::new(op_cost, depth)
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding finite costs sums the op costs and keeps the larger depth,
    /// whichever operand comes first.
    #[test]
    fn add_cost() {
        let (shallow, deep) = (Cost::new(5, 2), Cost::new(37, 3));
        let expected = Cost::new(42, 3);
        assert_eq!(shallow + deep, expected);
        assert_eq!(deep + shallow, expected);
    }

    /// Infinity absorbs any finite cost added to it, on either side.
    #[test]
    fn add_infinity() {
        let finite = Cost::new(5, 2);
        let infinite = Cost::infinity();
        assert_eq!(finite + infinite, Cost::infinity());
        assert_eq!(infinite + finite, Cost::infinity());
    }

    /// A sum whose op cost exceeds the representable range becomes infinite.
    #[test]
    fn op_cost_saturates_to_infinity() {
        let nearly_max = Cost::new(Cost::MAX_OP_COST - 10, 2);
        let nudge = Cost::new(11, 2);
        assert_eq!(nearly_max + nudge, Cost::infinity());
        assert_eq!(nudge + nearly_max, Cost::infinity());
    }

    /// The depth of a pure op stops growing once it reaches `u8::MAX`.
    #[test]
    fn depth_saturates_to_max_depth() {
        let deepest = Cost::new(10, u8::MAX);
        let shallow = Cost::new(10, 1);
        let expected = Cost::new(21, u8::MAX);
        for operands in [[deepest, shallow], [shallow, deepest]] {
            assert_eq!(Cost::of_pure_op(Opcode::Iconst, operands), expected);
        }
    }
}
238}