Skip to main content

cranelift_codegen/egraph/
cost.rs

1//! Cost functions for egraph representation.
2
3use crate::ir::Opcode;
4
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary unit that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// We start by defining costs for each opcode (see `Cost::of_opcode`
/// below). The cost of computing a pure value is the cost of its
/// opcode plus the costs of computing its inputs (`Cost::of_pure_op`).
/// The cost of an operation in the side-effectful skeleton is the cost
/// of its opcode plus its arity (`Cost::of_skeleton_op`).
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping.
///
/// Finally, the highest value, `u32::MAX`, is the "infinite" cost
/// returned by `Cost::infinity()`. An infinite cost is used to represent
/// a value that cannot be computed, or otherwise serves as a sentinel
/// when searching for the lowest-cost representation of a value.
///
/// Note that saturation does *not* stop short of this sentinel. Adding
/// finite costs whose sum overflows a `u32` yields exactly
/// `Cost::infinity()`. An infinite cost therefore means "uncomputable or
/// too expensive to represent", not strictly "uncomputable".
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Cost(u32);
35
36impl core::fmt::Debug for Cost {
37    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38        if *self == Cost::infinity() {
39            write!(f, "Cost::Infinite")
40        } else {
41            f.debug_tuple("Cost::Finite").field(&self.cost()).finish()
42        }
43    }
44}
45
46impl Cost {
47    pub(crate) fn infinity() -> Cost {
48        // 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
49        // only for heuristics and always saturate so this suffices!)
50        Cost(u32::MAX)
51    }
52
53    pub(crate) fn zero() -> Cost {
54        Cost(0)
55    }
56
57    /// Construct a new `Cost`.
58    fn new(cost: u32) -> Cost {
59        Cost(cost)
60    }
61
62    fn cost(&self) -> u32 {
63        self.0
64    }
65
66    /// Return the cost of an opcode.
67    fn of_opcode(op: Opcode) -> Cost {
68        match op {
69            // Constants.
70            Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1),
71
72            // Extends/reduces.
73            Opcode::Uextend
74            | Opcode::Sextend
75            | Opcode::Ireduce
76            | Opcode::Iconcat
77            | Opcode::Isplit => Cost::new(1),
78
79            // "Simple" arithmetic.
80            Opcode::Iadd
81            | Opcode::Isub
82            | Opcode::Band
83            | Opcode::Bor
84            | Opcode::Bxor
85            | Opcode::Bnot
86            | Opcode::Ishl
87            | Opcode::Ushr
88            | Opcode::Sshr => Cost::new(3),
89
90            // "Expensive" arithmetic.
91            Opcode::Imul => Cost::new(10),
92
93            // Everything else.
94            _ => {
95                // By default, be slightly more expensive than "simple"
96                // arithmetic.
97                let mut c = Cost::new(4);
98
99                // And then get more expensive as the opcode does more side
100                // effects.
101                if op.can_trap() || op.other_side_effects() {
102                    c = c + Cost::new(10);
103                }
104                if op.can_load() {
105                    c = c + Cost::new(20);
106                }
107                if op.can_store() {
108                    c = c + Cost::new(50);
109                }
110
111                c
112            }
113        }
114    }
115
116    /// Compute the cost of the operation and its given operands.
117    ///
118    /// Caller is responsible for checking that the opcode came from an instruction
119    /// that satisfies `inst_predicates::is_pure_for_egraph()`.
120    pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
121        let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
122        Cost::new(c.cost())
123    }
124
125    /// Compute the cost of an operation in the side-effectful skeleton.
126    pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {
127        Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap())
128    }
129}
130
131impl core::iter::Sum<Cost> for Cost {
132    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
133        iter.fold(Self::zero(), |a, b| a + b)
134    }
135}
136
137impl core::default::Default for Cost {
138    fn default() -> Cost {
139        Cost::zero()
140    }
141}
142
143impl core::ops::Add<Cost> for Cost {
144    type Output = Cost;
145
146    fn add(self, other: Cost) -> Cost {
147        Cost::new(self.cost().saturating_add(other.cost()))
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_cost() {
        let a = Cost::new(5);
        let b = Cost::new(37);
        assert_eq!(a + b, Cost::new(42));
        assert_eq!(b + a, Cost::new(42));
    }

    #[test]
    fn add_infinity() {
        let a = Cost::new(5);
        let b = Cost::infinity();
        assert_eq!(a + b, Cost::infinity());
        assert_eq!(b + a, Cost::infinity());
    }

    #[test]
    fn op_cost_saturates_to_infinity() {
        let a = Cost::new(u32::MAX - 10);
        let b = Cost::new(11);
        assert_eq!(a + b, Cost::infinity());
        assert_eq!(b + a, Cost::infinity());
    }

    #[test]
    fn zero_is_additive_identity_and_default() {
        let a = Cost::new(17);
        assert_eq!(a + Cost::zero(), a);
        assert_eq!(Cost::zero() + a, a);
        assert_eq!(Cost::default(), Cost::zero());
    }

    #[test]
    fn sum_of_costs() {
        // An empty sum is zero.
        let empty: Cost = core::iter::empty::<Cost>().sum();
        assert_eq!(empty, Cost::zero());

        let total: Cost = [Cost::new(1), Cost::new(2), Cost::new(3)]
            .into_iter()
            .sum();
        assert_eq!(total, Cost::new(6));

        // Summing saturates just like `+`.
        let saturated: Cost = [Cost::new(u32::MAX - 1), Cost::new(1), Cost::new(1)]
            .into_iter()
            .sum();
        assert_eq!(saturated, Cost::infinity());
    }

    #[test]
    fn largest_finite_cost_is_below_infinity() {
        assert!(Cost::new(u32::MAX - 1) < Cost::infinity());
    }

    #[test]
    fn pure_and_skeleton_op_costs() {
        // `iadd` costs 3, plus operands 1 + 2.
        assert_eq!(
            Cost::of_pure_op(Opcode::Iadd, [Cost::new(1), Cost::new(2)]),
            Cost::new(6)
        );
        // `iconst` costs 1 and has no operands.
        assert_eq!(
            Cost::of_pure_op(Opcode::Iconst, core::iter::empty::<Cost>()),
            Cost::new(1)
        );
        // `imul` costs 10, plus one unit per operand.
        assert_eq!(Cost::of_skeleton_op(Opcode::Imul, 2), Cost::new(12));
    }
}