Skip to main content

wasmtime/runtime/vm/instance/allocator/pooling/
unix_stack_pool.rs

1#![cfg_attr(asan, allow(dead_code))]
2
3use super::index_allocator::{SimpleIndexAllocator, SlotId};
4use crate::prelude::*;
5use crate::runtime::vm::{
6    HostAlignedByteCount, Mmap, PoolingInstanceAllocatorConfig, mmap::AlignedLength,
7};
8
9/// Represents a pool of execution stacks (used for the async fiber implementation).
10///
11/// Each index into the pool represents a single execution stack. The maximum number of
12/// stacks is the same as the maximum number of instances.
13///
14/// As stacks grow downwards, each stack starts (lowest address) with a guard page
15/// that can be used to detect stack overflow.
16///
17/// The top of the stack (starting stack pointer) is returned when a stack is allocated
18/// from the pool.
19#[derive(Debug)]
20pub struct StackPool {
21    mapping: Mmap<AlignedLength>,
22    stack_size: HostAlignedByteCount,
23    max_stacks: usize,
24    page_size: HostAlignedByteCount,
25    index_allocator: SimpleIndexAllocator,
26    async_stack_zeroing: bool,
27    async_stack_keep_resident: HostAlignedByteCount,
28}
29
30impl StackPool {
31    #[cfg(test)]
32    pub fn enabled() -> bool {
33        true
34    }
35
36    pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result<Self> {
37        use rustix::mm::{MprotectFlags, mprotect};
38
39        let page_size = HostAlignedByteCount::host_page_size();
40
41        // Add a page to the stack size for the guard page when using fiber stacks
42        let stack_size = if config.stack_size == 0 {
43            HostAlignedByteCount::ZERO
44        } else {
45            HostAlignedByteCount::new_rounded_up(config.stack_size)
46                .and_then(|size| size.checked_add(HostAlignedByteCount::host_page_size()))
47                .context("stack size exceeds addressable memory")?
48        };
49
50        let max_stacks = usize::try_from(config.limits.total_stacks).unwrap();
51
52        let allocation_size = stack_size
53            .checked_mul(max_stacks)
54            .context("total size of execution stacks exceeds addressable memory")?;
55
56        let mapping = Mmap::accessible_reserved(allocation_size, allocation_size)
57            .context("failed to create stack pool mapping")?;
58
59        // Set up the stack guard pages.
60        if !allocation_size.is_zero() {
61            unsafe {
62                for i in 0..max_stacks {
63                    // Safety: i < max_stacks and we've already checked that
64                    // stack_size * max_stacks is valid.
65                    let offset = stack_size.unchecked_mul(i);
66                    // Make the stack guard page inaccessible.
67                    let bottom_of_stack = mapping.as_ptr().add(offset.byte_count()).cast_mut();
68                    mprotect(
69                        bottom_of_stack.cast(),
70                        page_size.byte_count(),
71                        MprotectFlags::empty(),
72                    )
73                    .context("failed to protect stack guard page")?;
74                }
75            }
76        }
77
78        Ok(Self {
79            mapping,
80            stack_size,
81            max_stacks,
82            page_size,
83            async_stack_zeroing: config.async_stack_zeroing,
84            async_stack_keep_resident: HostAlignedByteCount::new_rounded_up(
85                config.async_stack_keep_resident,
86            )?,
87            index_allocator: SimpleIndexAllocator::new(config.limits.total_stacks),
88        })
89    }
90
91    /// Are there zero slots in use right now?
92    pub fn is_empty(&self) -> bool {
93        self.index_allocator.is_empty()
94    }
95
96    /// Allocate a new fiber.
97    pub fn allocate(&self) -> Result<wasmtime_fiber::FiberStack> {
98        if self.stack_size.is_zero() {
99            bail!("pooling allocator not configured to enable fiber stack allocation");
100        }
101
102        let index = self
103            .index_allocator
104            .alloc()
105            .ok_or_else(|| super::PoolConcurrencyLimitError::new(self.max_stacks, "fibers"))?
106            .index();
107
108        assert!(index < self.max_stacks);
109
110        unsafe {
111            // Remove the guard page from the size
112            let size_without_guard = self.stack_size.checked_sub(self.page_size).expect(
113                "self.stack_size is host-page-aligned and is > 0,\
114                 so it must be >= self.page_size",
115            );
116
117            let bottom_of_stack = self
118                .mapping
119                .as_ptr()
120                .add(self.stack_size.unchecked_mul(index).byte_count())
121                .cast_mut();
122
123            let stack = wasmtime_fiber::FiberStack::from_raw_parts(
124                bottom_of_stack,
125                self.page_size.byte_count(),
126                size_without_guard.byte_count(),
127            )?;
128            Ok(stack)
129        }
130    }
131
132    /// Zero the given stack, if we are configured to do so.
133    ///
134    /// This will call the given `decommit` function for each region of memory
135    /// that should be decommitted. It is the caller's responsibility to ensure
136    /// that those decommits happen before this stack is reused.
137    ///
138    /// # Panics
139    ///
140    /// `zero_stack` panics if the passed in `stack` was not created by
141    /// [`Self::allocate`].
142    ///
143    /// # Safety
144    ///
145    /// The stack must no longer be in use, and ready for returning to the pool
146    /// after it is zeroed and decommitted.
147    pub unsafe fn zero_stack(
148        &self,
149        stack: &mut wasmtime_fiber::FiberStack,
150        mut decommit: impl FnMut(*mut u8, usize),
151    ) -> usize {
152        assert!(stack.is_from_raw_parts());
153        assert!(
154            !self.stack_size.is_zero(),
155            "pooling allocator not configured to enable fiber stack allocation \
156             (Self::allocate should have returned an error)"
157        );
158
159        if !self.async_stack_zeroing {
160            return 0;
161        }
162
163        let top = stack
164            .top()
165            .expect("fiber stack not allocated from the pool") as usize;
166
167        let base = self.mapping.as_ptr() as usize;
168        let len = self.mapping.len();
169        assert!(
170            top > base && top <= (base + len),
171            "fiber stack top pointer not in range"
172        );
173
174        // Remove the guard page from the size.
175        let stack_size = self.stack_size.checked_sub(self.page_size).expect(
176            "self.stack_size is host-page-aligned and is > 0,\
177             so it must be >= self.page_size",
178        );
179        let bottom_of_stack = top - stack_size.byte_count();
180        let start_of_stack = bottom_of_stack - self.page_size.byte_count();
181        assert!(start_of_stack >= base && start_of_stack < (base + len));
182        assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);
183
184        // Manually zero the top of the stack to keep the pages resident in
185        // memory and avoid future page faults. Use the system to deallocate
186        // pages past this. This hopefully strikes a reasonable balance between:
187        //
188        // * memset for the whole range is probably expensive
189        // * madvise for the whole range incurs expensive future page faults
190        // * most threads probably don't use most of the stack anyway
191        let size_to_memset = stack_size.min(self.async_stack_keep_resident);
192        let rest = stack_size
193            .checked_sub(size_to_memset)
194            .expect("stack_size >= size_to_memset");
195
196        // SAFETY: this function's own contract requires that the stack is not
197        // in use so it's safe to pave over part of it with zero.
198        unsafe {
199            std::ptr::write_bytes(
200                (bottom_of_stack + rest.byte_count()) as *mut u8,
201                0,
202                size_to_memset.byte_count(),
203            );
204        }
205
206        // Use the system to reset remaining stack pages to zero.
207        decommit(bottom_of_stack as _, rest.byte_count());
208
209        size_to_memset.byte_count()
210    }
211
212    /// Deallocate a previously-allocated fiber.
213    ///
214    /// # Safety
215    ///
216    /// The fiber must have been allocated by this pool, must be in an allocated
217    /// state, and must never be used again.
218    ///
219    /// The caller must have already called `zero_stack` on the fiber stack and
220    /// flushed any enqueued decommits for this stack's memory.
221    pub unsafe fn deallocate(&self, stack: wasmtime_fiber::FiberStack, bytes_resident: usize) {
222        assert!(stack.is_from_raw_parts());
223
224        let top = stack
225            .top()
226            .expect("fiber stack not allocated from the pool") as usize;
227
228        let base = self.mapping.as_ptr() as usize;
229        let len = self.mapping.len();
230        assert!(
231            top > base && top <= (base + len),
232            "fiber stack top pointer not in range"
233        );
234
235        // Remove the guard page from the size
236        let stack_size = self.stack_size.byte_count() - self.page_size.byte_count();
237        let bottom_of_stack = top - stack_size;
238        let start_of_stack = bottom_of_stack - self.page_size.byte_count();
239        assert!(start_of_stack >= base && start_of_stack < (base + len));
240        assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);
241
242        let index = (start_of_stack - base) / self.stack_size.byte_count();
243        assert!(index < self.max_stacks);
244        let index = u32::try_from(index).unwrap();
245
246        self.index_allocator.free(SlotId(index), bytes_resident);
247    }
248
249    pub fn unused_warm_slots(&self) -> u32 {
250        self.index_allocator.unused_warm_slots()
251    }
252
253    pub fn unused_bytes_resident(&self) -> Option<usize> {
254        if self.async_stack_zeroing {
255            Some(self.index_allocator.unused_bytes_resident())
256        } else {
257            None
258        }
259    }
260}
261
// Only built for unix test runs with the `async` feature, and never under
// miri or asan.
#[cfg(all(test, unix, feature = "async", not(miri), not(asan)))]
mod tests {
    use super::*;
    use crate::runtime::vm::InstanceLimits;

    /// Fills a ten-slot pool, checks that one more allocation is refused, and
    /// then returns every stack and inspects the resulting free list.
    #[test]
    fn test_stack_pool() -> Result<()> {
        let limits = InstanceLimits {
            total_stacks: 10,
            ..Default::default()
        };
        let config = PoolingInstanceAllocatorConfig {
            limits,
            stack_size: 1,
            async_stack_zeroing: true,
            ..PoolingInstanceAllocatorConfig::default()
        };
        let pool = StackPool::new(&config)?;

        // A one-byte request rounds up to one page, plus one guard page.
        let native_page_size = crate::runtime::vm::host_page_size();
        assert_eq!(pool.stack_size, 2 * native_page_size);
        assert_eq!(pool.max_stacks, 10);
        assert_eq!(pool.page_size, native_page_size);

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        let base = pool.mapping.as_ptr() as usize;
        let slot_bytes = pool.stack_size.byte_count();

        // Slots are expected to be handed out in ascending order; the top of
        // slot `n` sits exactly `n + 1` slot sizes above the pool base.
        let mut stacks = Vec::new();
        for expected_slot in 0..10 {
            let stack = pool.allocate().expect("allocation should succeed");
            let top = stack.top().unwrap() as usize;
            assert_eq!(((top - base) / slot_bytes) - 1, expected_slot);
            stacks.push(stack);
        }

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        assert!(pool.allocate().is_err(), "allocation should fail");

        stacks
            .into_iter()
            .for_each(|stack| unsafe { pool.deallocate(stack, 0) });

        assert_eq!(
            pool.index_allocator.testing_freelist(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].map(SlotId),
        );

        Ok(())
    }
}