// wasmtime/runtime/component/concurrent/futures_and_streams/buffers.rs

1use crate::prelude::*;
2#[cfg(feature = "component-model-bytes")]
3use bytes::{Bytes, BytesMut};
4use core::mem::{self, MaybeUninit};
5use core::slice;
6
7// Inner module here to restrict possible readers of the fields of
8// `UntypedWriteBuffer`.
9pub use untyped::*;
10mod untyped {
11    use super::WriteBuffer;
12    use crate::vm::SendSyncPtr;
13    use core::any::TypeId;
14    use core::marker;
15    use core::mem;
16    use core::ptr::NonNull;
17
18    /// Helper structure to type-erase the `T` in `WriteBuffer<T>`.
19    ///
20    /// This is constructed with a `&mut dyn WriteBuffer<T>` and then can only
21    /// be viewed as `&mut dyn WriteBuffer<T>` as well. The `T`, however, is
22    /// carried through methods rather than the struct itself.
23    ///
24    /// Note that this structure has a lifetime `'a` which forces an active
25    /// borrow on the original buffer passed in.
26    pub struct UntypedWriteBuffer<'a> {
27        element_type_id: TypeId,
28        buf: SendSyncPtr<dyn WriteBuffer<()>>,
29        _marker: marker::PhantomData<&'a mut dyn WriteBuffer<()>>,
30    }
31
32    /// Helper structure to transmute between `WriteBuffer<T>` and
33    /// `WriteBuffer<()>`.
34    union ReinterpretWriteBuffer<T> {
35        typed: *mut dyn WriteBuffer<T>,
36        untyped: *mut dyn WriteBuffer<()>,
37    }
38
39    impl<'a> UntypedWriteBuffer<'a> {
40        /// Creates a new `UntypedWriteBuffer` from the `buf` provided.
41        ///
42        /// The returned value can be used with the `get_mut` method to get the
43        /// original write buffer back.
44        pub fn new<T: 'static>(buf: &'a mut dyn WriteBuffer<T>) -> UntypedWriteBuffer<'a> {
45            UntypedWriteBuffer {
46                element_type_id: TypeId::of::<T>(),
47                // SAFETY: this is `unsafe` due to reading union fields. That
48                // is safe here because `typed` and `untyped` have the same size
49                // and we're otherwise reinterpreting a raw pointer with a type
50                // parameter to one without one.
51                buf: SendSyncPtr::new(
52                    NonNull::new(unsafe {
53                        let r = ReinterpretWriteBuffer { typed: buf };
54                        assert_eq!(mem::size_of_val(&r.typed), mem::size_of_val(&r.untyped));
55                        r.untyped
56                    })
57                    .unwrap(),
58                ),
59                _marker: marker::PhantomData,
60            }
61        }
62
63        /// Acquires the underlying `WriteBuffer<T>` this was created with.
64        ///
65        /// # Panics
66        ///
67        /// Panics if `T` does not match the type that this was created with.
68        pub fn get_mut<T: 'static>(&mut self) -> &mut dyn WriteBuffer<T> {
69            assert_eq!(self.element_type_id, TypeId::of::<T>());
70            // SAFETY: the `T` has been checked with `TypeId` and this
71            // structure also is proof of valid existence of the original
72            // `&mut WriteBuffer<T>`, so taking the raw pointer back to a safe
73            // reference is valid.
74            unsafe {
75                &mut *ReinterpretWriteBuffer {
76                    untyped: self.buf.as_ptr(),
77                }
78                .typed
79            }
80        }
81    }
82}
83
/// A source of items of type `T` which may be written to a `StreamWriter`.
///
/// See also [`crate::component::StreamProducer`].
///
/// # Safety
///
/// Implementations must uphold the contract of [`WriteBuffer::take`]: every
/// element of the slice handed to the closure must be a fully initialized
/// `T`. Callers read the items out of that slice without any further checks,
/// which is why this trait is `unsafe` to implement.
pub unsafe trait WriteBuffer<T>: Send + Sync + 'static {
    /// Returns the items that are still waiting to be read.
    fn remaining(&self) -> &[T];

    /// Skips over, and drops, the next `count` items.
    fn skip(&mut self, count: usize);

    /// Hands ownership of the next `count` items to `fun`.
    ///
    /// The items are removed from `self` and passed to `fun` as a single
    /// contiguous slice. `fun` may assume that every element is fully
    /// initialized and is expected to read all of them out; anything it does
    /// not read is leaked.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than `self.remaining().len()`. If `fun`
    /// itself panics then items may be leaked.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>]));
}
115
116/// Trait representing a buffer which may be used to read from a `StreamReader`.
117///
118/// See also [`crate::component::Source`].
119pub trait ReadBuffer<T>: Send + Sync + 'static {
120    /// Move the specified items into this buffer.
121    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I);
122
123    /// Number of items which may be read before this buffer is full.
124    fn remaining_capacity(&self) -> usize;
125
126    /// Move (i.e. take ownership of) the specified items into this buffer.
127    ///
128    /// This method will drain `count` items from the `input` provided and move
129    /// ownership into this buffer.
130    ///
131    /// # Panics
132    ///
133    /// This method will panic if `count` is larger than
134    /// `self.remaining_capacity()` or if it's larger than `input.remaining()`.
135    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize);
136}
137
138pub(super) struct Extender<'a, B>(pub(super) &'a mut B);
139
140impl<T, B: ReadBuffer<T>> Extend<T> for Extender<'_, B> {
141    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
142        self.0.extend(iter)
143    }
144}
145
146// SAFETY: the `take` implementation below guarantees that the `fun` closure is
147// provided with fully initialized items.
148unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for Option<T> {
149    fn remaining(&self) -> &[T] {
150        if let Some(me) = self {
151            slice::from_ref(me)
152        } else {
153            &[]
154        }
155    }
156
157    fn skip(&mut self, count: usize) {
158        match count {
159            0 => {}
160            1 => {
161                assert!(self.is_some());
162                *self = None;
163            }
164            _ => panic!("cannot skip more than {} item(s)", self.remaining().len()),
165        }
166    }
167
168    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
169        match count {
170            0 => fun(&mut []),
171            1 => {
172                let mut item = MaybeUninit::new(self.take().unwrap());
173                fun(slice::from_mut(&mut item));
174            }
175            _ => panic!("cannot forget more than {} item(s)", self.remaining().len()),
176        }
177    }
178}
179
180impl<T: Send + Sync + 'static> ReadBuffer<T> for Option<T> {
181    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
182        let mut iter = iter.into_iter();
183        if self.is_none() {
184            *self = iter.next();
185        }
186        assert!(iter.next().is_none());
187    }
188
189    fn remaining_capacity(&self) -> usize {
190        if self.is_some() { 0 } else { 1 }
191    }
192
193    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
194        match count {
195            0 => {}
196            1 => {
197                assert!(self.is_none());
198                input.take(1, &mut |slice| {
199                    // SAFETY: Per the `WriteBuffer` trait contract this block
200                    // has ownership of the items in `slice` and they're all
201                    // valid to take.
202                    unsafe {
203                        *self = Some(slice[0].assume_init_read());
204                    }
205                });
206            }
207            _ => panic!(
208                "cannot take more than {} item(s)",
209                self.remaining_capacity()
210            ),
211        }
212    }
213}
214
/// A `WriteBuffer` implementation over a window of a `Vec<u8>`.
///
/// The window is described by a read position (`offset`) and an exclusive end
/// (`limit`) into the backing vector.
pub struct SliceBuffer {
    buffer: Vec<u8>,
    offset: usize,
    limit: usize,
}

impl SliceBuffer {
    /// Creates a buffer over `buffer`, positioned at `offset` and ending at
    /// `limit`.
    ///
    /// # Panics
    ///
    /// Panics if `offset > limit` or if `limit > buffer.len()`.
    pub fn new(buffer: Vec<u8>, offset: usize, limit: usize) -> Self {
        assert!(offset <= limit);
        assert!(limit <= buffer.len());
        SliceBuffer {
            buffer,
            offset,
            limit,
        }
    }

    /// Consumes `self`, returning the backing vector, the current offset, and
    /// the limit, in that order.
    pub fn into_parts(self) -> (Vec<u8>, usize, usize) {
        let SliceBuffer {
            buffer,
            offset,
            limit,
        } = self;
        (buffer, offset, limit)
    }
}
237
238// SAFETY: the `take` implementation below guarantees that the `fun` closure is
239// provided with fully initialized items due to all elements in the slice being
240// initialized.
241unsafe impl WriteBuffer<u8> for SliceBuffer {
242    fn remaining(&self) -> &[u8] {
243        &self.buffer[self.offset..self.limit]
244    }
245
246    fn skip(&mut self, count: usize) {
247        assert!(self.offset + count <= self.limit);
248        self.offset += count;
249    }
250
251    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
252        assert!(count <= self.remaining().len());
253        self.offset += count;
254        // SAFETY: Transmuting from `&[u8]` to `&[MaybeUninit<u8>]` should
255        // always be sound.
256        fun(unsafe {
257            mem::transmute::<&[u8], &[MaybeUninit<u8>]>(
258                &self.buffer[self.offset - count..self.offset],
259            )
260        });
261    }
262}
263
/// A `WriteBuffer` implementation, backed by a `Vec`.
///
/// Elements are stored as `MaybeUninit<T>`. Everything from `offset` onward
/// is initialized and still owned by this buffer; that invariant is what the
/// methods below rely on.
pub struct VecBuffer<T> {
    buffer: Vec<MaybeUninit<T>>,
    offset: usize,
}

impl<T> Default for VecBuffer<T> {
    /// Creates an empty buffer, equivalent to `with_capacity(0)`.
    fn default() -> Self {
        VecBuffer::with_capacity(0)
    }
}

impl<T> VecBuffer<T> {
    /// Creates an empty buffer whose backing vector has the requested
    /// capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        let buffer = Vec::with_capacity(capacity);
        VecBuffer { buffer, offset: 0 }
    }

    /// Drops every remaining item and rewinds to the start, keeping the
    /// backing vector's capacity.
    pub fn reset(&mut self) {
        let pending = self.remaining_().len();
        self.skip_(pending);
        self.buffer.clear();
        self.offset = 0;
    }

    /// Returns the not-yet-consumed items as a typed slice.
    fn remaining_(&self) -> &[T] {
        let live = &self.buffer[self.offset..];
        // SAFETY: the type's invariant, maintained by every method here, is
        // that all elements from `self.offset` onward are initialized, so
        // viewing them as `T` is valid.
        unsafe { mem::transmute::<&[MaybeUninit<T>], &[T]>(live) }
    }

    /// Drops the next `count` items in place and advances past them.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than the number of remaining items.
    fn skip_(&mut self, count: usize) {
        assert!(count <= self.remaining_().len());
        let start = self.offset;
        for slot in self.buffer[start..start + count].iter_mut() {
            // Advance before dropping so that a panicking destructor cannot
            // cause the same item to be dropped a second time.
            self.offset += 1;
            // SAFETY: `slot` lies at or after the offset this call started
            // from, so it is initialized (see `remaining_`), and the offset
            // bump above ensures it is never touched again.
            unsafe {
                slot.assume_init_drop();
            }
        }
    }
}
313
314// SAFETY: the `take` implementation below guarantees that the `fun` closure is
315// provided with fully initialized items due to `self.offset`-and-onwards being
316// always initialized.
317unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for VecBuffer<T> {
318    fn remaining(&self) -> &[T] {
319        self.remaining_()
320    }
321
322    fn skip(&mut self, count: usize) {
323        self.skip_(count)
324    }
325
326    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
327        assert!(count <= self.remaining().len());
328        // Note that the offset here is incremented before `fun` is called to
329        // ensure that if `fun` panics that the items are still considered
330        // transferred.
331        self.offset += count;
332        fun(&self.buffer[self.offset - count..self.offset]);
333    }
334}
335
336impl<T> From<Vec<T>> for VecBuffer<T> {
337    fn from(buffer: Vec<T>) -> Self {
338        Self {
339            // SAFETY: Transmuting from `Vec<T>` to `Vec<MaybeUninit<T>>` should
340            // be sound for any `T`.
341            buffer: unsafe { mem::transmute::<Vec<T>, Vec<MaybeUninit<T>>>(buffer) },
342            offset: 0,
343        }
344    }
345}
346
347impl<T> Drop for VecBuffer<T> {
348    fn drop(&mut self) {
349        self.reset();
350    }
351}
352
353impl<T: Send + Sync + 'static> ReadBuffer<T> for Vec<T> {
354    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
355        Extend::extend(self, iter)
356    }
357
358    fn remaining_capacity(&self) -> usize {
359        self.capacity().checked_sub(self.len()).unwrap()
360    }
361
362    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
363        assert!(count <= self.remaining_capacity());
364        input.take(count, &mut |slice| {
365            for item in slice {
366                // SAFETY: Per the `WriteBuffer` implementation contract this
367                // function has exclusive ownership of all items in `slice` so
368                // this is safe to take and transfer them here.
369                self.push(unsafe { item.assume_init_read() });
370            }
371        });
372    }
373}
374
// SAFETY: `take` passes `fun` a view of bytes split off the front of this
// `Bytes`, all of which are initialized.
#[cfg(feature = "component-model-bytes")]
unsafe impl WriteBuffer<u8> for Bytes {
    /// Returns the full contents of this `Bytes` as a slice.
    fn remaining(&self) -> &[u8] {
        &self[..]
    }

    /// Splits the first `count` bytes off the front and drops them.
    fn skip(&mut self, count: usize) {
        drop(self.split_to(count));
    }

    /// Splits the first `count` bytes off the front and passes them to `fun`.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        let head = self.split_to(count);
        let view = unsafe_byte_slice(&head);
        fun(view);
    }
}
392
// SAFETY: `take` passes `fun` a view of bytes split off the front of this
// `BytesMut`, all of which are initialized.
#[cfg(feature = "component-model-bytes")]
unsafe impl WriteBuffer<u8> for BytesMut {
    /// Returns the full contents of this `BytesMut` as a slice.
    fn remaining(&self) -> &[u8] {
        &self[..]
    }

    /// Splits the first `count` bytes off the front and drops them.
    fn skip(&mut self, count: usize) {
        drop(self.split_to(count));
    }

    /// Splits the first `count` bytes off the front and passes them to `fun`.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        let head = self.split_to(count);
        let view = unsafe_byte_slice(&head);
        fun(view);
    }
}
410
#[cfg(feature = "component-model-bytes")]
impl ReadBuffer<u8> for BytesMut {
    /// Appends every byte of `iter` using `BytesMut`'s `Extend<u8>`
    /// implementation.
    fn extend<I: IntoIterator<Item = u8>>(&mut self, iter: I) {
        <Self as Extend<u8>>::extend(self, iter)
    }

    /// Returns the spare capacity, i.e. `capacity() - len()`.
    fn remaining_capacity(&self) -> usize {
        let (capacity, len) = (self.capacity(), self.len());
        capacity.checked_sub(len).unwrap()
    }

    /// Copies `count` bytes taken from `input` onto the end of this buffer.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than `self.remaining_capacity()`.
    fn move_from(&mut self, input: &mut dyn WriteBuffer<u8>, count: usize) {
        assert!(count <= self.remaining_capacity());
        input.take(count, &mut |slice| {
            // SAFETY: the `WriteBuffer` contract guarantees that every element
            // of `slice` is fully initialized, so it can be viewed as plain
            // bytes.
            let bytes = unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) };
            self.extend_from_slice(bytes);
        });
    }
}
432
433#[cfg(feature = "component-model-bytes")]
/// Views a slice of initialized bytes as a slice of `MaybeUninit<u8>`.
///
/// The returned slice covers the same memory and has the same length as the
/// input.
fn unsafe_byte_slice(slice: &[u8]) -> &[MaybeUninit<u8>] {
    let (data, len) = (slice.as_ptr(), slice.len());
    // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, the pointer and
    // length come from a valid slice that stays borrowed for the returned
    // lifetime, and treating initialized bytes as possibly-uninitialized
    // through a shared reference is always valid.
    unsafe { slice::from_raw_parts(data.cast::<MaybeUninit<u8>>(), len) }
}
439
#[cfg(test)]
mod tests {
    use super::*;

    // Each test checks both the transferred items and what is left behind,
    // rather than only lengths, so that moving the wrong element or moving
    // elements in the wrong order is caught.

    /// `VecBuffer` items can be moved into both a `Vec` and an `Option`.
    #[test]
    fn test_vec_buffer_take() {
        let mut buf = VecBuffer::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

        let mut dst = Vec::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(dst, ["a"]);
        assert_eq!(buf.remaining(), ["b", "c"]);

        // Keep the destination in a named binding so the moved value can be
        // inspected instead of being dropped with a temporary.
        let mut slot = None;
        slot.move_from(&mut buf, 1);
        assert_eq!(slot.as_deref(), Some("b"));
        assert_eq!(buf.remaining(), ["c"]);
        assert_eq!(dst, ["a"]);
    }

    /// `SliceBuffer` hands out bytes from its offset and advances it.
    #[test]
    fn test_slice_buffer_take() {
        let mut buf = SliceBuffer::new(vec![1, 2, 3], 0, 3);
        let mut dst = Vec::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(dst, [1]);
        assert_eq!(buf.remaining(), [2, 3]);
        assert_eq!(buf.into_parts(), (vec![1, 2, 3], 1, 3));
    }

    /// `Bytes` can feed both a `Vec<u8>` and a `BytesMut`.
    #[test]
    #[cfg(feature = "component-model-bytes")]
    fn test_bytes_take() {
        let mut buf = Bytes::from(&b"123"[..]);
        let mut dst = Vec::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(dst, b"1");
        assert_eq!(buf.remaining(), b"23");

        let mut dst = BytesMut::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(&dst[..], b"2");
        assert_eq!(buf.remaining(), b"3");
    }

    /// `BytesMut` can feed both a `Vec<u8>` and another `BytesMut`.
    #[test]
    #[cfg(feature = "component-model-bytes")]
    fn test_bytes_mut_take() {
        let mut buf = BytesMut::from(&b"123"[..]);
        let mut dst = Vec::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(dst, b"1");
        assert_eq!(buf.remaining(), b"23");

        let mut dst = BytesMut::with_capacity(1);
        dst.move_from(&mut buf, 1);
        assert_eq!(&dst[..], b"2");
        assert_eq!(buf.remaining(), b"3");
    }
}