Skip to main content

wasmtime/runtime/component/concurrent/futures_and_streams/
buffers.rs

1#[cfg(feature = "component-model-async-bytes")]
2use bytes::{Bytes, BytesMut};
3#[cfg(feature = "component-model-async-bytes")]
4use std::io::Cursor;
5use std::mem::{self, MaybeUninit};
6use std::slice;
7use std::vec::Vec;
8
9// Inner module here to restrict possible readers of the fields of
10// `UntypedWriteBuffer`.
11pub use untyped::*;
12mod untyped {
13    use super::WriteBuffer;
14    use crate::vm::SendSyncPtr;
15    use std::any::TypeId;
16    use std::marker;
17    use std::mem;
18    use std::ptr::NonNull;
19
20    /// Helper structure to type-erase the `T` in `WriteBuffer<T>`.
21    ///
22    /// This is constructed with a `&mut dyn WriteBuffer<T>` and then can only
23    /// be viewed as `&mut dyn WriteBuffer<T>` as well. The `T`, however, is
24    /// carried through methods rather than the struct itself.
25    ///
26    /// Note that this structure has a lifetime `'a` which forces an active
27    /// borrow on the original buffer passed in.
28    pub struct UntypedWriteBuffer<'a> {
29        element_type_id: TypeId,
30        buf: SendSyncPtr<dyn WriteBuffer<()>>,
31        _marker: marker::PhantomData<&'a mut dyn WriteBuffer<()>>,
32    }
33
34    /// Helper structure to transmute between `WriteBuffer<T>` and
35    /// `WriteBuffer<()>`.
36    union ReinterpretWriteBuffer<T> {
37        typed: *mut dyn WriteBuffer<T>,
38        untyped: *mut dyn WriteBuffer<()>,
39    }
40
41    impl<'a> UntypedWriteBuffer<'a> {
42        /// Creates a new `UntypedWriteBuffer` from the `buf` provided.
43        ///
44        /// The returned value can be used with the `get_mut` method to get the
45        /// original write buffer back.
46        pub fn new<T: 'static>(buf: &'a mut dyn WriteBuffer<T>) -> UntypedWriteBuffer<'a> {
47            UntypedWriteBuffer {
48                element_type_id: TypeId::of::<T>(),
49                // SAFETY: this is `unsafe` due to reading union fields. That
50                // is safe here because `typed` and `untyped` have the same size
51                // and we're otherwise reinterpreting a raw pointer with a type
52                // parameter to one without one.
53                buf: SendSyncPtr::new(
54                    NonNull::new(unsafe {
55                        let r = ReinterpretWriteBuffer { typed: buf };
56                        assert_eq!(mem::size_of_val(&r.typed), mem::size_of_val(&r.untyped));
57                        r.untyped
58                    })
59                    .unwrap(),
60                ),
61                _marker: marker::PhantomData,
62            }
63        }
64
65        /// Acquires the underlying `WriteBuffer<T>` this was created with.
66        ///
67        /// # Panics
68        ///
69        /// Panics if `T` does not match the type that this was created with.
70        pub fn get_mut<T: 'static>(&mut self) -> &mut dyn WriteBuffer<T> {
71            assert_eq!(self.element_type_id, TypeId::of::<T>());
72            // SAFETY: the `T` has been checked with `TypeId` and this
73            // structure also is proof of valid existence of the original
74            // `&mut WriteBuffer<T>`, so taking the raw pointer back to a safe
75            // reference is valid.
76            unsafe {
77                &mut *ReinterpretWriteBuffer {
78                    untyped: self.buf.as_ptr(),
79                }
80                .typed
81            }
82        }
83    }
84}
85
/// Trait representing a buffer which may be written to a `StreamWriter`.
///
/// See also [`crate::component::StreamProducer`].
///
/// # Safety
///
/// This trait is `unsafe` to implement because of the contract of
/// [`WriteBuffer::take`]. Implementations must only pass items to the closure
/// that are fully initialized values of `T`. Those items are transferred to
/// the closure, which may move them out, so the implementation must not drop
/// or hand them out again afterwards.
pub unsafe trait WriteBuffer<T>: Send + Sync + 'static {
    /// Slice of items remaining to be read.
    fn remaining(&self) -> &[T];

    /// Skip and drop the specified number of items.
    ///
    /// # Panics
    ///
    /// Implementations are expected to panic if `count` is larger than
    /// `self.remaining().len()`.
    fn skip(&mut self, count: usize);

    /// Take ownership of the specified number of items.
    ///
    /// This function will take `count` items from `self` and pass them as a
    /// contiguous slice to the closure `fun` provided. The `fun` closure may
    /// assume that the items are all fully initialized and available to read.
    /// It is expected that `fun` will read all the items provided. Any items
    /// that aren't read by `fun` will be leaked.
    ///
    /// # Panics
    ///
    /// Panics if `count` is larger than `self.remaining().len()`. If `fun`
    /// panics then items may be leaked.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>]));
}
117
118/// Trait representing a buffer which may be used to read from a `StreamReader`.
119///
120/// See also [`crate::component::Source`].
121pub trait ReadBuffer<T>: Send + Sync + 'static {
122    /// Move the specified items into this buffer.
123    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I);
124
125    /// Number of items which may be read before this buffer is full.
126    fn remaining_capacity(&self) -> usize;
127
128    /// Move (i.e. take ownership of) the specified items into this buffer.
129    ///
130    /// This method will drain `count` items from the `input` provided and move
131    /// ownership into this buffer.
132    ///
133    /// # Panics
134    ///
135    /// This method will panic if `count` is larger than
136    /// `self.remaining_capacity()` or if it's larger than `input.remaining()`.
137    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize);
138}
139
140pub(super) struct Extender<'a, B>(pub(super) &'a mut B);
141
142impl<T, B: ReadBuffer<T>> Extend<T> for Extender<'_, B> {
143    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
144        self.0.extend(iter)
145    }
146}
147
148// SAFETY: the `take` implementation below guarantees that the `fun` closure is
149// provided with fully initialized items.
150unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for Option<T> {
151    fn remaining(&self) -> &[T] {
152        if let Some(me) = self {
153            slice::from_ref(me)
154        } else {
155            &[]
156        }
157    }
158
159    fn skip(&mut self, count: usize) {
160        match count {
161            0 => {}
162            1 => {
163                assert!(self.is_some());
164                *self = None;
165            }
166            _ => panic!("cannot skip more than {} item(s)", self.remaining().len()),
167        }
168    }
169
170    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
171        match count {
172            0 => fun(&mut []),
173            1 => {
174                let mut item = MaybeUninit::new(self.take().unwrap());
175                fun(slice::from_mut(&mut item));
176            }
177            _ => panic!("cannot forget more than {} item(s)", self.remaining().len()),
178        }
179    }
180}
181
182impl<T: Send + Sync + 'static> ReadBuffer<T> for Option<T> {
183    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
184        let mut iter = iter.into_iter();
185        if self.is_none() {
186            *self = iter.next();
187        }
188        assert!(iter.next().is_none());
189    }
190
191    fn remaining_capacity(&self) -> usize {
192        if self.is_some() { 0 } else { 1 }
193    }
194
195    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
196        match count {
197            0 => {}
198            1 => {
199                assert!(self.is_none());
200                input.take(1, &mut |slice| {
201                    // SAFETY: Per the `WriteBuffer` trait contract this block
202                    // has ownership of the items in `slice` and they're all
203                    // valid to take.
204                    unsafe {
205                        *self = Some(slice[0].assume_init_read());
206                    }
207                });
208            }
209            _ => panic!(
210                "cannot take more than {} item(s)",
211                self.remaining_capacity()
212            ),
213        }
214    }
215}
216
/// A `WriteBuffer` of bytes that owns a `Vec<u8>` and hands out the window
/// `buffer[offset..limit]`.
pub struct SliceBuffer {
    /// Owned byte storage.
    buffer: Vec<u8>,
    /// Index of the next byte to hand out.
    offset: usize,
    /// Exclusive end of the readable window.
    limit: usize,
}

impl SliceBuffer {
    /// Wraps `buffer`, exposing the bytes in `offset..limit`.
    ///
    /// # Panics
    ///
    /// Panics unless `offset <= limit <= buffer.len()`.
    pub fn new(buffer: Vec<u8>, offset: usize, limit: usize) -> Self {
        assert!(offset <= limit);
        assert!(limit <= buffer.len());
        Self { buffer, offset, limit }
    }

    /// Consumes the buffer, returning `(buffer, offset, limit)` so the caller
    /// can see how far it has been consumed.
    pub fn into_parts(self) -> (Vec<u8>, usize, usize) {
        let Self {
            buffer,
            offset,
            limit,
        } = self;
        (buffer, offset, limit)
    }
}
239
240// SAFETY: the `take` implementation below guarantees that the `fun` closure is
241// provided with fully initialized items due to all elements in the slice being
242// initialized.
243unsafe impl WriteBuffer<u8> for SliceBuffer {
244    fn remaining(&self) -> &[u8] {
245        &self.buffer[self.offset..self.limit]
246    }
247
248    fn skip(&mut self, count: usize) {
249        assert!(self.offset + count <= self.limit);
250        self.offset += count;
251    }
252
253    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
254        assert!(count <= self.remaining().len());
255        self.offset += count;
256        // SAFETY: Transmuting from `&[u8]` to `&[MaybeUninit<u8>]` should
257        // always be sound.
258        fun(unsafe {
259            mem::transmute::<&[u8], &[MaybeUninit<u8>]>(
260                &self.buffer[self.offset - count..self.offset],
261            )
262        });
263    }
264}
265
/// A `WriteBuffer` implementation that owns its items in a `Vec`.
///
/// Items before `offset` have already been skipped or taken. Every slot from
/// `offset` onward holds a live, initialized `T` still owned by this buffer.
pub struct VecBuffer<T> {
    /// Backing storage. Only `buffer[offset..]` is guaranteed initialized.
    buffer: Vec<MaybeUninit<T>>,
    /// Index of the first item not yet skipped or taken.
    offset: usize,
}
271
272impl<T> Default for VecBuffer<T> {
273    fn default() -> Self {
274        Self::with_capacity(0)
275    }
276}
277
278impl<T> VecBuffer<T> {
279    /// Create a new instance with the specified capacity.
280    pub fn with_capacity(capacity: usize) -> Self {
281        Self {
282            buffer: Vec::with_capacity(capacity),
283            offset: 0,
284        }
285    }
286
287    /// Reset the state of this buffer, removing all items and preserving its
288    /// capacity.
289    pub fn reset(&mut self) {
290        self.skip_(self.remaining_().len());
291        self.buffer.clear();
292        self.offset = 0;
293    }
294
295    fn remaining_(&self) -> &[T] {
296        // SAFETY: This relies on the invariant (upheld in the other methods of
297        // this type) that all the elements from `self.offset` onward are
298        // initialized and valid for `self.buffer`.
299        unsafe { mem::transmute::<&[MaybeUninit<T>], &[T]>(&self.buffer[self.offset..]) }
300    }
301
302    fn skip_(&mut self, count: usize) {
303        assert!(count <= self.remaining_().len());
304        for item in &mut self.buffer[self.offset..][..count] {
305            // Note that the offset is incremented first here to ensure that if
306            // any destructors panic we don't attempt to re-drop the item.
307            self.offset += 1;
308            // SAFETY: See comment in `Self::remaining`
309            unsafe {
310                item.assume_init_drop();
311            }
312        }
313    }
314}
315
316// SAFETY: the `take` implementation below guarantees that the `fun` closure is
317// provided with fully initialized items due to `self.offset`-and-onwards being
318// always initialized.
319unsafe impl<T: Send + Sync + 'static> WriteBuffer<T> for VecBuffer<T> {
320    fn remaining(&self) -> &[T] {
321        self.remaining_()
322    }
323
324    fn skip(&mut self, count: usize) {
325        self.skip_(count)
326    }
327
328    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<T>])) {
329        assert!(count <= self.remaining().len());
330        // Note that the offset here is incremented before `fun` is called to
331        // ensure that if `fun` panics that the items are still considered
332        // transferred.
333        self.offset += count;
334        fun(&self.buffer[self.offset - count..self.offset]);
335    }
336}
337
338impl<T> From<Vec<T>> for VecBuffer<T> {
339    fn from(buffer: Vec<T>) -> Self {
340        Self {
341            // SAFETY: Transmuting from `Vec<T>` to `Vec<MaybeUninit<T>>` should
342            // be sound for any `T`.
343            buffer: unsafe { mem::transmute::<Vec<T>, Vec<MaybeUninit<T>>>(buffer) },
344            offset: 0,
345        }
346    }
347}
348
349impl<T> Drop for VecBuffer<T> {
350    fn drop(&mut self) {
351        self.reset();
352    }
353}
354
355impl<T: Send + Sync + 'static> ReadBuffer<T> for Vec<T> {
356    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
357        Extend::extend(self, iter)
358    }
359
360    fn remaining_capacity(&self) -> usize {
361        self.capacity().checked_sub(self.len()).unwrap()
362    }
363
364    fn move_from(&mut self, input: &mut dyn WriteBuffer<T>, count: usize) {
365        assert!(count <= self.remaining_capacity());
366        input.take(count, &mut |slice| {
367            for item in slice {
368                // SAFETY: Per the `WriteBuffer` implementation contract this
369                // function has exclusive ownership of all items in `slice` so
370                // this is safe to take and transfer them here.
371                self.push(unsafe { item.assume_init_read() });
372            }
373        });
374    }
375}
376
// SAFETY: the `take` implementation below only passes `fun` bytes taken from
// the initialized contents of the underlying `Bytes`.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<Bytes> {
    /// The unread bytes. If the cursor is positioned at or past the end of
    /// the data, which `Cursor::set_position` permits, this is an empty slice
    /// rather than a panic.
    fn remaining(&self) -> &[u8] {
        let bytes: &[u8] = self.get_ref();
        usize::try_from(self.position())
            .ok()
            .and_then(|pos| bytes.get(pos..))
            .unwrap_or(&[])
    }

    /// Advances the cursor past `count` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining bytes.
    fn skip(&mut self, count: usize) {
        assert!(
            count <= self.remaining().len(),
            "tried to skip {count} with {} remaining",
            self.remaining().len()
        );
        self.set_position(
            self.position()
                .checked_add(u64::try_from(count).unwrap())
                .unwrap(),
        );
    }

    /// Hands the next `count` bytes to `fun`, then advances past them.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining bytes.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        assert!(count <= self.remaining().len());
        fun(unsafe_byte_slice(&self.remaining()[..count]));
        self.skip(count);
    }
}
404
// SAFETY: the `take` implementation below only passes `fun` bytes taken from
// the initialized contents of the underlying `BytesMut`.
#[cfg(feature = "component-model-async-bytes")]
unsafe impl WriteBuffer<u8> for Cursor<BytesMut> {
    /// The unread bytes. If the cursor is positioned at or past the end of
    /// the data, which `Cursor::set_position` permits, this is an empty slice
    /// rather than a panic.
    fn remaining(&self) -> &[u8] {
        let bytes: &[u8] = self.get_ref();
        usize::try_from(self.position())
            .ok()
            .and_then(|pos| bytes.get(pos..))
            .unwrap_or(&[])
    }

    /// Advances the cursor past `count` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining bytes.
    fn skip(&mut self, count: usize) {
        // Same diagnostic as the `Cursor<Bytes>` implementation.
        assert!(
            count <= self.remaining().len(),
            "tried to skip {count} with {} remaining",
            self.remaining().len()
        );
        self.set_position(
            self.position()
                .checked_add(u64::try_from(count).unwrap())
                .unwrap(),
        );
    }

    /// Hands the next `count` bytes to `fun`, then advances past them.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds the number of remaining bytes.
    fn take(&mut self, count: usize, fun: &mut dyn FnMut(&[MaybeUninit<u8>])) {
        assert!(count <= self.remaining().len());
        fun(unsafe_byte_slice(&self.remaining()[..count]));
        self.skip(count);
    }
}
428
#[cfg(feature = "component-model-async-bytes")]
impl ReadBuffer<u8> for BytesMut {
    /// Appends every byte yielded by `iter` via `BytesMut`'s `Extend<u8>` impl.
    fn extend<I: IntoIterator<Item = u8>>(&mut self, iter: I) {
        <Self as Extend<u8>>::extend(self, iter)
    }

    /// Spare capacity: `capacity() - len()`.
    fn remaining_capacity(&self) -> usize {
        let (capacity, len) = (self.capacity(), self.len());
        capacity.checked_sub(len).unwrap()
    }

    fn move_from(&mut self, input: &mut dyn WriteBuffer<u8>, count: usize) {
        assert!(count <= self.remaining_capacity());
        input.take(count, &mut |chunk| {
            // SAFETY: per the `WriteBuffer` contract every element of `chunk`
            // is fully initialized, and `MaybeUninit<u8>` has the same layout
            // as `u8`, so the slice may be viewed as plain bytes.
            let bytes = unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(chunk) };
            self.extend_from_slice(bytes);
        });
    }
}
450
/// Views a slice of initialized bytes as a slice of possibly-uninitialized
/// bytes, without copying.
#[cfg(feature = "component-model-async-bytes")]
fn unsafe_byte_slice(bytes: &[u8]) -> &[MaybeUninit<u8>] {
    let (ptr, len) = (bytes.as_ptr(), bytes.len());
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`, and
    // `ptr`/`len` describe a live slice whose lifetime the result inherits.
    // Treating initialized bytes as possibly-uninitialized is always sound.
    unsafe { slice::from_raw_parts(ptr.cast::<MaybeUninit<u8>>(), len) }
}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    /// Moves exactly one item out of `src` into a new `Vec` with one slot
    /// reserved, and returns that `Vec`.
    fn move_one_into_vec<T: Send + Sync + 'static>(src: &mut dyn WriteBuffer<T>) -> Vec<T> {
        let mut dst: Vec<T> = Vec::new();
        dst.reserve(1);
        dst.move_from(src, 1);
        dst
    }

    #[test]
    fn test_vec_buffer_take() {
        let strings = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let mut buf = VecBuffer::from(strings);

        let dst = move_one_into_vec::<String>(&mut buf);
        assert_eq!(buf.remaining().len(), 2);
        assert_eq!(dst.len(), 1);

        // A temporary `None` has room for one item, so this consumes another
        // item from `buf` and leaves `dst` untouched.
        Option::<String>::None.move_from(&mut buf, 1);
        assert_eq!(buf.remaining().len(), 1);
        assert_eq!(dst.len(), 1);
    }

    #[test]
    fn test_slice_buffer_take() {
        let mut buf = SliceBuffer::new(vec![1, 2, 3], 0, 3);

        let dst = move_one_into_vec::<u8>(&mut buf);
        assert_eq!(buf.remaining().len(), 2);
        assert_eq!(dst.len(), 1);
    }

    #[test]
    #[cfg(feature = "component-model-async-bytes")]
    fn test_cursor_bytes_take() {
        let mut buf = Cursor::new(Bytes::from(&b"123"[..]));

        let into_vec = move_one_into_vec::<u8>(&mut buf);
        assert_eq!(buf.remaining().len(), 2);
        assert_eq!(into_vec.len(), 1);

        let mut into_bytes = BytesMut::new();
        into_bytes.reserve(1);
        into_bytes.move_from(&mut buf, 1);
        assert_eq!(buf.remaining().len(), 1);
        assert_eq!(into_bytes.len(), 1);
    }

    #[test]
    #[cfg(feature = "component-model-async-bytes")]
    fn test_cursor_bytes_mut_take() {
        let mut buf = Cursor::new(BytesMut::from(&b"123"[..]));

        let into_vec = move_one_into_vec::<u8>(&mut buf);
        assert_eq!(buf.remaining().len(), 2);
        assert_eq!(into_vec.len(), 1);

        let mut into_bytes = BytesMut::new();
        into_bytes.reserve(1);
        into_bytes.move_from(&mut buf, 1);
        assert_eq!(buf.remaining().len(), 1);
        assert_eq!(into_bytes.len(), 1);
    }
}