wasmtime/runtime/vm/byte_count.rs

use core::fmt;

use super::host_page_size;

/// A number of bytes that's guaranteed to be aligned to the host page size.
///
/// This is used to manage page-aligned memory allocations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct HostAlignedByteCount(
    // Invariant: this is always a multiple of the host page size.
    usize,
);

impl HostAlignedByteCount {
    /// A zero byte count.
    pub const ZERO: Self = Self(0);

    /// Creates a new `HostAlignedByteCount` from an aligned byte count.
    ///
    /// Returns an error if `bytes` is not page-aligned.
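    ///
    /// ## Example
    ///
    /// ```ignore
    /// // Illustrative only: assumes a 4 KiB host page size.
    /// assert!(HostAlignedByteCount::new(0x1000).is_ok());
    /// assert!(HostAlignedByteCount::new(0x1001).is_err());
    /// ```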
    pub fn new(bytes: usize) -> Result<Self, ByteCountNotAligned> {
        let host_page_size = host_page_size();
        if bytes % host_page_size == 0 {
            Ok(Self(bytes))
        } else {
            Err(ByteCountNotAligned(bytes))
        }
    }

    /// Creates a new `HostAlignedByteCount` from an aligned byte count without
    /// checking validity.
    ///
    /// ## Safety
    ///
    /// The caller must ensure that `bytes` is page-aligned.
    pub unsafe fn new_unchecked(bytes: usize) -> Self {
        debug_assert!(
            bytes % host_page_size() == 0,
            "byte count {bytes} is not page-aligned (page size = {})",
            host_page_size(),
        );
        Self(bytes)
    }

    /// Creates a new `HostAlignedByteCount`, rounding up to the nearest page.
    ///
    /// Returns an error if `bytes + page_size - 1` overflows.
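    ///
    /// ## Example
    ///
    /// ```ignore
    /// // Illustrative only: assumes a 4 KiB host page size.
    /// let count = HostAlignedByteCount::new_rounded_up(100).unwrap();
    /// assert_eq!(count.byte_count(), 0x1000);
    /// ```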
    pub fn new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds> {
        let page_size = host_page_size();
        debug_assert!(page_size.is_power_of_two());
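        // Round up with the usual power-of-two bit trick: add (page_size - 1),
        // then clear the low bits. For example, with 4 KiB (0x1000) pages,
        // 0x1001 + 0xfff = 0x2000 and 0x2000 & !0xfff = 0x2000, so one byte
        // past a page boundary rounds up to the next page, while an
        // already-aligned value is left unchanged.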
        match bytes.checked_add(page_size - 1) {
            Some(v) => Ok(Self(v & !(page_size - 1))),
            None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp)),
        }
    }

    /// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page.
    ///
    /// Returns an error if the `u64` overflows `usize`, or if `bytes +
    /// page_size - 1` overflows.
    pub fn new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds> {
        let bytes = bytes
            .try_into()
            .map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::ConvertU64))?;
        Self::new_rounded_up(bytes)
    }

    /// Returns the host page size.
    pub fn host_page_size() -> HostAlignedByteCount {
        // The host page size is always a multiple of itself.
        HostAlignedByteCount(host_page_size())
    }

    /// Returns true if the byte count is zero.
    #[inline]
    pub fn is_zero(self) -> bool {
        self == Self::ZERO
    }

    /// Returns the number of bytes as a `usize`.
    #[inline]
    pub fn byte_count(self) -> usize {
        self.0
    }

    /// Add two aligned byte counts together.
    ///
    /// Returns an error if the result overflows.
    pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
        // aligned + aligned = aligned
        self.0
            .checked_add(bytes.0)
            .map(Self)
            .ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add))
    }

    // Note: saturating_add should not be naively added! usize::MAX is one
    // less than a power of 2, so it is not page-aligned; saturating there
    // would break the alignment invariant.

    /// Compute `self - bytes`.
    ///
    /// Returns an error if the result underflows.
    pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
        // aligned - aligned = aligned
        self.0
            .checked_sub(bytes.0)
            .map(Self)
            .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub))
    }

    /// Compute `self - bytes`, returning zero if the result underflows.
    #[inline]
    pub fn saturating_sub(self, bytes: HostAlignedByteCount) -> Self {
        // aligned - aligned = aligned, and 0 is always aligned.
        Self(self.0.saturating_sub(bytes.0))
    }

    /// Multiply an aligned byte count by a scalar value.
    ///
    /// Returns an error if the result overflows.
    pub fn checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds> {
        // aligned * scalar = aligned
        self.0
            .checked_mul(scalar)
            .map(Self)
            .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul))
    }

    /// Divide an aligned byte count by another aligned byte count, producing a
    /// scalar value.
    ///
    /// Returns an error if the divisor is zero.
    pub fn checked_div(self, divisor: HostAlignedByteCount) -> Result<usize, ByteCountOutOfBounds> {
        self.0
            .checked_div(divisor.0)
            .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Div))
    }

    /// Compute the remainder of an aligned byte count divided by another
    /// aligned byte count.
    ///
    /// The remainder is always an aligned byte count itself.
    ///
    /// Returns an error if the divisor is zero.
    pub fn checked_rem(self, divisor: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
        // Why is the remainder an aligned byte count? For example, if the page
        // size is 4KiB, then the remainder of dividing (say) 40KiB by 16KiB is
        // 8KiB, which is a multiple of 4KiB.
        //
        // More generally, for integers n >= 0, m > 0, k > 0:
        //
        //     (n * k) % (m * k) = (n % m) * k
        //
        // which is a multiple of k. Here, k is the host page size, so the
        // remainder is a multiple of the host page size.
        self.0
            .checked_rem(divisor.0)
            .map(Self)
            .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Rem))
    }

    /// Unchecked multiplication by a scalar value.
    ///
    /// ## Safety
    ///
    /// The result must not overflow.
    #[inline]
    pub unsafe fn unchecked_mul(self, n: usize) -> Self {
        Self(self.0 * n)
    }
}

impl PartialEq<usize> for HostAlignedByteCount {
    #[inline]
    fn eq(&self, other: &usize) -> bool {
        self.0 == *other
    }
}

impl PartialEq<HostAlignedByteCount> for usize {
    #[inline]
    fn eq(&self, other: &HostAlignedByteCount) -> bool {
        *self == other.0
    }
}

struct LowerHexDisplay<T>(T);

impl<T: fmt::LowerHex> fmt::Display for LowerHexDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use the LowerHex impl as the Display impl, ensuring that there's
        // always a 0x at the beginning (i.e. that the alternate formatter is
        // used).
        if f.alternate() {
            fmt::LowerHex::fmt(&self.0, f)
        } else {
            // Unfortunately, fill and alignment aren't respected this way, but
            // it's quite hard to construct a new formatter with mostly the same
            // options but the alternate flag set.
            // https://github.com/rust-lang/rust/pull/118159 would make this
            // easier.
            write!(f, "{:#x}", self.0)
        }
    }
}

impl fmt::Display for HostAlignedByteCount {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Defer to LowerHexDisplay so the output always includes a leading 0x.
        fmt::Display::fmt(&LowerHexDisplay(self.0), f)
    }
}

impl fmt::LowerHex for HostAlignedByteCount {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::LowerHex::fmt(&self.0, f)
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ByteCountNotAligned(usize);

impl fmt::Display for ByteCountNotAligned {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "byte count not page-aligned: {}",
            LowerHexDisplay(self.0)
        )
    }
}

impl core::error::Error for ByteCountNotAligned {}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind);

impl fmt::Display for ByteCountOutOfBounds {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl core::error::Error for ByteCountOutOfBounds {}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteCountOutOfBoundsKind {
    // We don't carry the arguments that errored out to avoid the error type
    // becoming too big.
    RoundUp,
    ConvertU64,
    Add,
    Sub,
    Mul,
    Div,
    Rem,
}

impl fmt::Display for ByteCountOutOfBoundsKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ByteCountOutOfBoundsKind::RoundUp => f.write_str("byte count overflow rounding up"),
            ByteCountOutOfBoundsKind::ConvertU64 => {
                f.write_str("byte count overflow converting u64")
            }
            ByteCountOutOfBoundsKind::Add => f.write_str("byte count overflow during addition"),
            ByteCountOutOfBoundsKind::Sub => f.write_str("byte count underflow during subtraction"),
            ByteCountOutOfBoundsKind::Mul => {
                f.write_str("byte count overflow during multiplication")
            }
            ByteCountOutOfBoundsKind::Div => f.write_str("division by zero"),
            ByteCountOutOfBoundsKind::Rem => f.write_str("remainder by zero"),
        }
    }
}

#[cfg(test)]
mod proptest_impls {
    use super::*;

    use proptest::prelude::*;

    impl Arbitrary for HostAlignedByteCount {
        type Strategy = BoxedStrategy<Self>;
        type Parameters = ();

        fn arbitrary_with(_: ()) -> Self::Strategy {
            // Compute the number of pages that fit in a usize, rounded down.
            // For example, if:
            //
            // * usize::MAX is 2**64 - 1
            // * host_page_size is 2**12 (4KiB)
            //
            // Then page_count = floor(usize::MAX / host_page_size) = 2**52 - 1.
            // The range 0..=page_count, when multiplied by the page size, will
            // produce values in the range 0..=(2**64 - 2**12), in steps of
            // 2**12, uniformly at random. This is the desired uniform
            // distribution of byte counts.
            let page_count = usize::MAX / host_page_size();
            (0..=page_count)
                .prop_map(|n| HostAlignedByteCount::new(n * host_page_size()).unwrap())
                .boxed()
        }
    }
}
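
// A minimal sketch of a property test exercising the `Arbitrary` impl above:
// every generated value upholds the alignment invariant by construction, but
// it is cheap to spot-check.
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn arbitrary_is_aligned(count: HostAlignedByteCount) {
            prop_assert_eq!(count.byte_count() % host_page_size(), 0);
        }
    }
}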

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_count_display() {
        // Pages are assumed to be 64k or smaller, so 65536 is page-aligned on
        // any supported host.
        let byte_count = HostAlignedByteCount::new(65536).unwrap();

        assert_eq!(format!("{byte_count}"), "0x10000");
        assert_eq!(format!("{byte_count:x}"), "10000");
        assert_eq!(format!("{byte_count:#x}"), "0x10000");
    }

    #[test]
    fn byte_count_ops() {
        let host_page_size = host_page_size();
        HostAlignedByteCount::new(0).expect("0 is aligned");
        HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned");
        HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned");
        HostAlignedByteCount::new(host_page_size + 1)
            .expect_err("host_page_size + 1 is not aligned");
        HostAlignedByteCount::new(host_page_size / 2)
            .expect_err("host_page_size / 2 is not aligned");

        // Rounding up.
        HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows");
        assert_eq!(
            HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size)
                .expect("(usize::MAX - 1 page) is in bounds"),
            HostAlignedByteCount::new((usize::MAX - host_page_size) + 1)
                .expect("usize::MAX is 2**N - 1"),
        );

        // Addition.
        let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1)
            .expect("(usize::MAX >> 1) + 1 is aligned");
        half_max
            .checked_add(HostAlignedByteCount::host_page_size())
            .expect("half max + page size is in bounds");
        half_max
            .checked_add(half_max)
            .expect_err("half max + half max is out of bounds");

        // Subtraction.
        let half_max_minus_one = half_max
            .checked_sub(HostAlignedByteCount::host_page_size())
            .expect("(half_max - 1 page) is in bounds");
        assert_eq!(
            half_max.checked_sub(half_max),
            Ok(HostAlignedByteCount::ZERO)
        );
        assert_eq!(
            half_max.checked_sub(half_max_minus_one),
            Ok(HostAlignedByteCount::host_page_size())
        );
        half_max_minus_one
            .checked_sub(half_max)
            .expect_err("(half_max - 1 page) - half_max is out of bounds");

        // Multiplication.
        half_max
            .checked_mul(2)
            .expect_err("half max * 2 is out of bounds");
        half_max_minus_one
            .checked_mul(2)
            .expect("(half max - 1 page) * 2 is in bounds");
    }
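
    // A sketch of division/remainder coverage, phrased relative to the host
    // page size since absolute values vary across hosts.
    #[test]
    fn byte_count_div_rem() {
        let page = HostAlignedByteCount::host_page_size();
        let two_pages = page.checked_mul(2).expect("2 pages is in bounds");
        let three_pages = page.checked_mul(3).expect("3 pages is in bounds");

        // Exact division leaves no remainder.
        assert_eq!(three_pages.checked_div(page), Ok(3));
        assert_eq!(three_pages.checked_rem(page), Ok(HostAlignedByteCount::ZERO));

        // 3 pages / 2 pages = 1, with 1 page left over; the remainder is
        // itself aligned.
        assert_eq!(three_pages.checked_div(two_pages), Ok(1));
        assert_eq!(three_pages.checked_rem(two_pages), Ok(page));

        // Dividing by zero is an error, not a panic.
        three_pages
            .checked_div(HostAlignedByteCount::ZERO)
            .expect_err("division by zero");
        three_pages
            .checked_rem(HostAlignedByteCount::ZERO)
            .expect_err("remainder by zero");
    }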
}