wasmtime/runtime/component/concurrent/futures_and_streams/buffers.rs

1#[cfg(feature = "component-model-async-bytes")]
2use bytes::{Bytes, BytesMut};
3#[cfg(feature = "component-model-async-bytes")]
4use std::io::Cursor;
5use std::mem::{self, MaybeUninit};
6use std::slice;
7use std::vec::Vec;
8
// Inner module here to restrict possible readers of the fields of
// `UntypedWriteBuffer`: only code inside `untyped` can touch the raw,
// type-erased pointer, so every typed access has to go through the `TypeId`
// check in `UntypedWriteBuffer::get_mut`.
pub use untyped::*;
mod untyped {
    use super::WriteBuffer;
    use std::any::TypeId;
    use std::marker;
    use std::mem;

    /// Helper structure to type-erase the `T` in `WriteBuffer<T>`.
    ///
    /// This is constructed with a `&mut dyn WriteBuffer<T>` and then can only
    /// be viewed as `&mut dyn WriteBuffer<T>` as well. The `T`, however, is
    /// carried through methods rather than the struct itself.
    ///
    /// Note that this structure has a lifetime `'a` which forces an active
    /// borrow on the original buffer passed in.
    pub struct UntypedWriteBuffer<'a> {
        // `TypeId` of the `T` this value was created with. `get_mut` asserts
        // against it before turning `buf` back into a typed reference.
        element_type_id: TypeId,
        // The original `*mut dyn WriteBuffer<T>` reinterpreted as a
        // `WriteBuffer<()>` pointer. Its pointer metadata still belongs to
        // `dyn WriteBuffer<T>`, so it must never be dereferenced or called
        // through at this erased type. It is only converted back in `get_mut`.
        buf: *mut dyn WriteBuffer<()>,
        // Carries the `'a` mutable borrow taken in `new`, so the original
        // buffer stays borrowed for as long as this value exists.
        _marker: marker::PhantomData<&'a mut dyn WriteBuffer<()>>,
    }

    /// Helper structure to transmute between `WriteBuffer<T>` and
    /// `WriteBuffer<()>` trait-object pointers.
    ///
    /// Both fields are raw `dyn` pointers. `new` asserts that they have the
    /// same size before reading one field as the other.
    union ReinterpretWriteBuffer<T> {
        typed: *mut dyn WriteBuffer<T>,
        untyped: *mut dyn WriteBuffer<()>,
    }

    impl<'a> UntypedWriteBuffer<'a> {
        /// Creates a new `UntypedWriteBuffer` from the `buf` provided.
        ///
        /// The returned value can be used with the `get_mut` method (using the
        /// same `T`) to get the original write buffer back.
        pub fn new<T: 'static>(buf: &'a mut dyn WriteBuffer<T>) -> UntypedWriteBuffer<'a> {
            UntypedWriteBuffer {
                element_type_id: TypeId::of::<T>(),
                // SAFETY: this is `unsafe` due to reading union fields. That
                // is safe here because `typed` and `untyped` have the same size
                // (double-checked by the `assert_eq!` below) and we're
                // otherwise reinterpreting a raw pointer with a type parameter
                // to one without one.
                buf: unsafe {
                    let r = ReinterpretWriteBuffer { typed: buf };
                    assert_eq!(mem::size_of_val(&r.typed), mem::size_of_val(&r.untyped));
                    r.untyped
                },
                _marker: marker::PhantomData,
            }
        }

        /// Acquires the underlying `WriteBuffer<T>` this was created with.
        ///
        /// The returned reference borrows `self`, so it cannot outlive this
        /// `UntypedWriteBuffer`.
        ///
        /// # Panics
        ///
        /// Panics if `T` does not match the type that this was created with.
        pub fn get_mut<T: 'static>(&mut self) -> &mut dyn WriteBuffer<T> {
            assert_eq!(self.element_type_id, TypeId::of::<T>());
            // SAFETY: the `T` has been checked with `TypeId` and this
            // structure also is proof of valid existence of the original
            // `&mut dyn WriteBuffer<T>`, so taking the raw pointer back to a
            // safe reference is valid.
            unsafe { &mut *ReinterpretWriteBuffer { untyped: self.buf }.typed }
        }
    }
}
75
/// Trait representing a buffer which may be written to a `StreamWriter`.
///
/// See also [`crate::component::Instance::stream`].
///
/// # Safety
///
/// This trait is unsafe due to the contract of the `take` function. This trait
/// is only safe to implement if `take` is implemented correctly, namely that:
///
/// * the slice passed to the closure contains exactly `count` items,
/// * every item in that slice is fully initialized for `T`, and
/// * ownership of those items moves to the closure, so afterwards the
///   implementation must treat them as moved-out and never drop them or hand
///   them out again.
pub unsafe trait WriteBuffer<T>: Send + Sync + 'static {
    /// Slice of items remaining to be read.
    fn remaining(&self) -> &[T];

    /// Skip and drop the specified number of items.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than `self.remaining().len()`.
    fn skip(&mut self, count: usize);

    /// Take ownership of the specified number of items.
    ///
    /// This function will take the first `count` items from `self` and pass
    /// them as a contiguous slice of length `count` to the closure `fun`
    /// provided. The `fun` closure may assume that the items are all fully
    /// initialized and available to read. It is expected that `fun` will read
    /// all the items provided. Any items that aren't read by `fun` will be
    /// leaked.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than `self.remaining().len()`. If `fun`
    /// panics then items may be leaked.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>]));
}
107
108/// Trait representing a buffer which may be used to read from a `StreamReader`.
109///
110/// See also [`crate::component::Instance::stream`].
111pub trait ReadBuffer<T>: Send + Sync + 'static {
112    /// Move the specified items into this buffer.
113    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I);
114
115    /// Number of items which may be read before this buffer is full.
116    fn remaining_capacity(&self) -> usize;
117
118    /// Move (i.e. take ownership of) the specified items into this buffer.
119    ///
120    /// This method will drain `count` items from the `input` provided and move
121    /// ownership into this buffer.
122    ///
123    /// # Panics
124    ///
125    /// This method will panic if `count` is larger than
126    /// `self.remaining_capacity()` or if it's larger than `input.remaining()`.
127    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize);
128}
129
130pub(super) struct Extender<'a, B>(pub(super) &'a mut B);
131
132impl<T, B: ReadBuffer<T>> Extend<T> for Extender<'_, B> {
133    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
134        self.0.extend(iter)
135    }
136}
137
138// SAFETY: the `take` implementation below guarantees that the `fun` closure is
139// provided with fully initialized items.
140unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for Option<T> {
141    fn remaining(&self) -> &[T] {
142        if let Some(me) = self {
143            slice::from_ref(me)
144        } else {
145            &[]
146        }
147    }
148
149    fn skip(&mut self, count: usize) {
150        match count {
151            0 => {}
152            1 => {
153                assert!(self.is_some());
154                *self = None;
155            }
156            _ => panic!("cannot skip more than {} item(s)", self.remaining().len()),
157        }
158    }
159
160    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
161        match count {
162            0 => fun(&mut []),
163            1 => {
164                let mut item = MaybeUninit::new(self.take().unwrap());
165                fun(slice::from_mut(&mut item));
166            }
167            _ => panic!("cannot forget more than {} item(s)", self.remaining().len()),
168        }
169    }
170}
171
172impl<T: Send + Sync + 'static> ReadBuffer<T> for Option<T> {
173    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
174        let mut iter = iter.into_iter();
175        if self.is_none() {
176            *self = iter.next();
177        }
178        assert!(iter.next().is_none());
179    }
180
181    fn remaining_capacity(&self) -> usize {
182        if self.is_some() { 0 } else { 1 }
183    }
184
185    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
186        match count {
187            0 => {}
188            1 => {
189                assert!(self.is_none());
190                input.take(1, &mut |slice| {
191                    // SAFETY: Per the `WriteBuffer` trait contract this block
192                    // has ownership of the items in `slice` and they're all
193                    // valid to take.
194                    unsafe {
195                        *self = Some(slice[0].assume_init_read());
196                    }
197                });
198            }
199            _ => panic!(
200                "cannot take more than {} item(s)",
201                self.remaining_capacity()
202            ),
203        }
204    }
205}
206
/// A `WriteBuffer` implementation, backed by a `Vec`.
///
/// Invariant: every element of `buffer` at index `offset` or later is
/// initialized and owned by this `VecBuffer`. Elements before `offset` have
/// already been taken or dropped and must never be read or dropped again.
pub struct VecBuffer<T> {
    buffer: Vec<MaybeUninit<T>>,
    offset: usize,
}

impl<T> VecBuffer<T> {
    /// Create a new, empty instance with the specified capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            offset: 0,
        }
    }

    /// Reset the state of this buffer, dropping all remaining items and
    /// preserving its capacity.
    pub fn reset(&mut self) {
        self.skip_(self.remaining_().len());
        // Every element has now been taken or dropped, so clearing the
        // `MaybeUninit` storage (which drops nothing) only resets the length.
        self.buffer.clear();
        self.offset = 0;
    }

    /// Returns the items that have not yet been taken or skipped.
    fn remaining_(&self) -> &[T] {
        let live = &self.buffer[self.offset..];
        // SAFETY: by the type's invariant every element from `self.offset`
        // onward is initialized. `MaybeUninit<T>` has the same size, alignment
        // and ABI as `T`, and casting a slice pointer preserves its length
        // metadata. Unlike transmuting the fat reference, this does not rely
        // on an unspecified reference layout.
        unsafe { &*(live as *const [MaybeUninit<T>] as *const [T]) }
    }

    /// Drops the next `count` remaining items.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining items.
    fn skip_(&mut self, count: usize) {
        assert!(count <= self.remaining_().len());
        for item in &mut self.buffer[self.offset..][..count] {
            // Note that the offset is incremented first here to ensure that if
            // any destructors panic we don't attempt to re-drop the item.
            self.offset += 1;
            // SAFETY: the item is at or past the old offset, so it is
            // initialized (see the invariant on `VecBuffer`), and the offset
            // has already moved past it so it will never be dropped again.
            unsafe {
                item.assume_init_drop();
            }
        }
    }
}
250
251// SAFETY: the `take` implementation below guarantees that the `fun` closure is
252// provided with fully initialized items due to `self.offset`-and-onwards being
253// always initialized.
254unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for VecBuffer<T> {
255    fn remaining(&self) -> &[T] {
256        self.remaining_()
257    }
258
259    fn skip(&mut self, count: usize) {
260        self.skip_(count)
261    }
262
263    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
264        assert!(count <= self.remaining().len());
265        // Note that the offset here is incremented before `fun` is called to
266        // ensure that if `fun` panics that the items are still considered
267        // transferred.
268        self.offset += count;
269        fun(&mut self.buffer[self.offset - count..]);
270    }
271}
272
273impl<T> From<Vec<T>> for VecBuffer<T> {
274    fn from(buffer: Vec<T>) -> Self {
275        Self {
276            // SAFETY: Transmuting from `Vec<T>` to `Vec<MaybeUninit<T>>` should
277            // be sound for any `T`.
278            buffer: unsafe { mem::transmute::<Vec<T>, Vec<MaybeUninit<T>>>(buffer) },
279            offset: 0,
280        }
281    }
282}
283
284impl<T> Drop for VecBuffer<T> {
285    fn drop(&mut self) {
286        self.reset();
287    }
288}
289
290impl<T: Send + Sync + 'static> ReadBuffer<T> for Vec<T> {
291    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
292        Extend::extend(self, iter)
293    }
294
295    fn remaining_capacity(&self) -> usize {
296        self.capacity().checked_sub(self.len()).unwrap()
297    }
298
299    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
300        assert!(count <= self.remaining_capacity());
301        input.take(count, &mut |slice| {
302            for item in slice {
303                // SAFETY: Per the `WriteBuffer` implementation contract this
304                // function has exclusive ownership of all items in `slice` so
305                // this is safe to take and transfer them here.
306                self.push(unsafe { item.assume_init_read() });
307            }
308        });
309    }
310}
311
// SAFETY: the `take` implementation below passes the `fun` closure exactly
// `count` bytes. All of them are initialized, since they are taken from a
// `&[u8]`.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<Bytes> {
    fn remaining(&self) -> &[u8] {
        let bytes: &[u8] = self.get_ref();
        let len = bytes.len();
        // `Cursor` allows its position to be set past the end of the data.
        // Treat that as "nothing remaining" rather than panicking on an
        // out-of-range slice, or on a failed `u64` -> `usize` conversion on
        // 32-bit targets.
        let start = usize::try_from(self.position()).map_or(len, |pos| pos.min(len));
        &bytes[start..]
    }

    fn skip(&mut self, count: usize) {
        assert!(
            count <= self.remaining().len(),
            "tried to skip {count} with {} remaining",
            self.remaining().len()
        );
        self.set_position(
            self.position()
                .checked_add(u64::try_from(count).unwrap())
                .unwrap(),
        );
    }

    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        let remaining = self.remaining();
        assert!(count <= remaining.len());
        // Hand over only the `count` bytes being consumed. Passing the whole
        // remaining slice would let readers take bytes that are not skipped
        // below, and those bytes would then be delivered a second time.
        fun(unsafe_byte_slice(&remaining[..count]));
        self.skip(count);
    }
}
339
// SAFETY: the `take` implementation below passes the `fun` closure exactly
// `count` bytes. All of them are initialized, since they are taken from a
// `&[u8]`.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<BytesMut> {
    fn remaining(&self) -> &[u8] {
        let bytes: &[u8] = self.get_ref();
        let len = bytes.len();
        // `Cursor` allows its position to be set past the end of the data.
        // Treat that as "nothing remaining" rather than panicking on an
        // out-of-range slice, or on a failed `u64` -> `usize` conversion on
        // 32-bit targets.
        let start = usize::try_from(self.position()).map_or(len, |pos| pos.min(len));
        &bytes[start..]
    }

    fn skip(&mut self, count: usize) {
        assert!(
            count <= self.remaining().len(),
            "tried to skip {count} with {} remaining",
            self.remaining().len()
        );
        self.set_position(
            self.position()
                .checked_add(u64::try_from(count).unwrap())
                .unwrap(),
        );
    }

    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        let remaining = self.remaining();
        assert!(count <= remaining.len());
        // Hand over only the `count` bytes being consumed. Passing the whole
        // remaining slice would let readers take bytes that are not skipped
        // below, and those bytes would then be delivered a second time.
        fun(unsafe_byte_slice(&remaining[..count]));
        self.skip(count);
    }
}
363
#[cfg(feature = "component-model-async-bytes")]
impl ReadBuffer<u8> for BytesMut {
    fn extend<I: IntoIterator<Item = u8>>(&mut self, iter: I) {
        Extend::extend(self, iter)
    }

    fn remaining_capacity(&self) -> usize {
        self.capacity().checked_sub(self.len()).unwrap()
    }

    fn move_from(&mut self, input: &mut dyn WriteBuffer<u8>, count: usize) {
        assert!(count <= self.remaining_capacity());
        input.take(count, &mut |slice| {
            // SAFETY: per the contract of `WriteBuffer` every element of
            // `slice` is fully initialized. `MaybeUninit<u8>` has the same
            // layout as `u8`, and casting the slice pointer preserves its
            // length, so viewing it as `&[u8]` is sound without relying on
            // the layout of fat references.
            let bytes = unsafe { &*(slice as *const [MaybeUninit<u8>] as *const [u8]) };
            self.extend_from_slice(bytes);
        });
    }
}
385
/// Views a byte slice as a slice of possibly-uninitialized bytes, which is
/// the element type `WriteBuffer::take` hands to its closure.
#[cfg(feature = "component-model-async-bytes")]
fn unsafe_byte_slice(slice: &[u8]) -> &[MaybeUninit<u8>] {
    // SAFETY: `MaybeUninit<u8>` has the same size, alignment and ABI as `u8`,
    // and every initialized `u8` is a valid `MaybeUninit<u8>`. The result is
    // a shared slice, so nothing uninitialized can be written through it.
    // Casting the slice pointer (rather than transmuting the reference)
    // preserves the length metadata without relying on fat-reference layout.
    unsafe { &*(slice as *const [u8] as *const [MaybeUninit<u8>]) }
}