wasmtime/runtime/vm/instance/allocator/pooling/unix_stack_pool.rs

#![cfg_attr(asan, allow(dead_code))]

use super::index_allocator::{SimpleIndexAllocator, SlotId};
use crate::prelude::*;
use crate::runtime::vm::sys::vm::commit_pages;
use crate::runtime::vm::{
    HostAlignedByteCount, Mmap, PoolingInstanceAllocatorConfig, mmap::AlignedLength,
};

/// Represents a pool of execution stacks (used for the async fiber implementation).
///
/// Each index into the pool represents a single execution stack. The maximum number of
/// stacks is the same as the maximum number of instances.
///
/// As stacks grow downwards, each stack starts (lowest address) with a guard page
/// that can be used to detect stack overflow.
///
/// The top of the stack (starting stack pointer) is returned when a stack is allocated
/// from the pool.
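///
/// Each slot occupies `stack_size` bytes: one guard page plus the usable
/// stack. A sketch of a single slot, lowest address on the left (inferred from
/// the guard-page setup and pointer arithmetic below):
///
/// ```text
/// +------------+------------------------------------+
/// | guard page |     usable stack (grows down)      |
/// +------------+------------------------------------+
/// ^            ^                                    ^
/// slot start   bottom of usable stack          top (returned to callers)
/// ```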
#[derive(Debug)]
pub struct StackPool {
    mapping: Mmap<AlignedLength>,
    stack_size: HostAlignedByteCount,
    max_stacks: usize,
    page_size: HostAlignedByteCount,
    index_allocator: SimpleIndexAllocator,
    async_stack_zeroing: bool,
    async_stack_keep_resident: HostAlignedByteCount,
}

impl StackPool {
    #[cfg(test)]
    pub fn enabled() -> bool {
        true
    }

    pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result<Self> {
        use rustix::mm::{MprotectFlags, mprotect};

        let page_size = HostAlignedByteCount::host_page_size();

        // Add a page to the stack size for the guard page when using fiber stacks
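        // (For example, with 4 KiB host pages a requested `stack_size` of 1
        // rounds up to one page and then gains a guard page, i.e. an 8 KiB
        // slot; this is the case exercised by `test_stack_pool` below.)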
        let stack_size = if config.stack_size == 0 {
            HostAlignedByteCount::ZERO
        } else {
            HostAlignedByteCount::new_rounded_up(config.stack_size)
                .and_then(|size| size.checked_add(HostAlignedByteCount::host_page_size()))
                .context("stack size exceeds addressable memory")?
        };

        let max_stacks = usize::try_from(config.limits.total_stacks).unwrap();

        let allocation_size = stack_size
            .checked_mul(max_stacks)
            .context("total size of execution stacks exceeds addressable memory")?;

        let mapping = Mmap::accessible_reserved(allocation_size, allocation_size)
            .context("failed to create stack pool mapping")?;

        // Set up the stack guard pages.
        if !allocation_size.is_zero() {
            unsafe {
                for i in 0..max_stacks {
                    // Safety: i < max_stacks and we've already checked that
                    // stack_size * max_stacks is valid.
                    let offset = stack_size.unchecked_mul(i);
                    // Make the stack guard page inaccessible.
                    let bottom_of_stack = mapping.as_ptr().add(offset.byte_count()).cast_mut();
                    mprotect(
                        bottom_of_stack.cast(),
                        page_size.byte_count(),
                        MprotectFlags::empty(),
                    )
                    .context("failed to protect stack guard page")?;
                }
            }
        }

        Ok(Self {
            mapping,
            stack_size,
            max_stacks,
            page_size,
            async_stack_zeroing: config.async_stack_zeroing,
            async_stack_keep_resident: HostAlignedByteCount::new_rounded_up(
                config.async_stack_keep_resident,
            )?,
            index_allocator: SimpleIndexAllocator::new(config.limits.total_stacks),
        })
    }

    /// Are there zero slots in use right now?
    pub fn is_empty(&self) -> bool {
        self.index_allocator.is_empty()
    }

    /// Allocate a new fiber.
    pub fn allocate(&self) -> Result<wasmtime_fiber::FiberStack> {
        if self.stack_size.is_zero() {
            bail!("pooling allocator not configured to enable fiber stack allocation");
        }

        let index = self
            .index_allocator
            .alloc()
            .ok_or_else(|| super::PoolConcurrencyLimitError::new(self.max_stacks, "fibers"))?
            .index();

        assert!(index < self.max_stacks);

        unsafe {
            // Remove the guard page from the size
            let size_without_guard = self.stack_size.checked_sub(self.page_size).expect(
                "self.stack_size is host-page-aligned and is > 0, \
                 so it must be >= self.page_size",
            );

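            // Each slot starts at `index * stack_size` within the mapping, so
            // `bottom_of_stack` below points at the slot's guard page (the
            // page protected with `mprotect` in `new`); the usable stack is
            // the `size_without_guard` bytes directly above it.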
            let bottom_of_stack = self
                .mapping
                .as_ptr()
                .add(self.stack_size.unchecked_mul(index).byte_count())
                .cast_mut();

            commit_pages(bottom_of_stack, size_without_guard.byte_count())?;

            let stack = wasmtime_fiber::FiberStack::from_raw_parts(
                bottom_of_stack,
                self.page_size.byte_count(),
                size_without_guard.byte_count(),
            )?;
            Ok(stack)
        }
    }

    /// Zero the given stack, if we are configured to do so.
    ///
    /// This will call the given `decommit` function for each region of memory
    /// that should be decommitted. It is the caller's responsibility to ensure
    /// that those decommits happen before this stack is reused.
    ///
    /// # Panics
    ///
    /// `zero_stack` panics if the passed in `stack` was not created by
    /// [`Self::allocate`].
    ///
    /// # Safety
    ///
    /// The stack must no longer be in use, and ready for returning to the pool
    /// after it is zeroed and decommitted.
    pub unsafe fn zero_stack(
        &self,
        stack: &mut wasmtime_fiber::FiberStack,
        mut decommit: impl FnMut(*mut u8, usize),
    ) -> usize {
        assert!(stack.is_from_raw_parts());
        assert!(
            !self.stack_size.is_zero(),
            "pooling allocator not configured to enable fiber stack allocation \
             (Self::allocate should have returned an error)"
        );

        if !self.async_stack_zeroing {
            return 0;
        }

        let top = stack
            .top()
            .expect("fiber stack not allocated from the pool") as usize;

        let base = self.mapping.as_ptr() as usize;
        let len = self.mapping.len();
        assert!(
            top > base && top <= (base + len),
            "fiber stack top pointer not in range"
        );

        // Remove the guard page from the size.
        let stack_size = self.stack_size.checked_sub(self.page_size).expect(
            "self.stack_size is host-page-aligned and is > 0, \
             so it must be >= self.page_size",
        );
        let bottom_of_stack = top - stack_size.byte_count();
        let start_of_stack = bottom_of_stack - self.page_size.byte_count();
        assert!(start_of_stack >= base && start_of_stack < (base + len));
        assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);

        // Manually zero the top of the stack to keep the pages resident in
        // memory and avoid future page faults. Use the system to deallocate
        // pages past this. This hopefully strikes a reasonable balance between:
        //
        // * memset for the whole range is probably expensive
        // * madvise for the whole range incurs expensive future page faults
        // * most threads probably don't use most of the stack anyway
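        //
        // E.g. with a 1 MiB usable stack and `async_stack_keep_resident` of
        // 64 KiB, the 64 KiB nearest the top of the stack are zeroed by the
        // `write_bytes` call below while the remaining 960 KiB are handed to
        // `decommit` (illustrative numbers, not defaults).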
        let size_to_memset = stack_size.min(self.async_stack_keep_resident);
        let rest = stack_size
            .checked_sub(size_to_memset)
            .expect("stack_size >= size_to_memset");

        // SAFETY: this function's own contract requires that the stack is not
        // in use so it's safe to pave over part of it with zero.
        unsafe {
            std::ptr::write_bytes(
                (bottom_of_stack + rest.byte_count()) as *mut u8,
                0,
                size_to_memset.byte_count(),
            );
        }

        // Use the system to reset remaining stack pages to zero.
        decommit(bottom_of_stack as _, rest.byte_count());

        size_to_memset.byte_count()
    }

    /// Deallocate a previously-allocated fiber.
    ///
    /// # Safety
    ///
    /// The fiber must have been allocated by this pool, must be in an allocated
    /// state, and must never be used again.
    ///
    /// The caller must have already called `zero_stack` on the fiber stack and
    /// flushed any enqueued decommits for this stack's memory.
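    ///
    /// A sketch of the intended call sequence (`queue` and `flush_decommits`
    /// are hypothetical stand-ins for the caller's decommit machinery, not
    /// APIs of this module):
    ///
    /// ```ignore
    /// let bytes_resident =
    ///     unsafe { pool.zero_stack(&mut stack, |ptr, len| queue.push((ptr, len))) };
    /// flush_decommits(&mut queue);
    /// unsafe { pool.deallocate(stack, bytes_resident) };
    /// ```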
    pub unsafe fn deallocate(&self, stack: wasmtime_fiber::FiberStack, bytes_resident: usize) {
        assert!(stack.is_from_raw_parts());

        let top = stack
            .top()
            .expect("fiber stack not allocated from the pool") as usize;

        let base = self.mapping.as_ptr() as usize;
        let len = self.mapping.len();
        assert!(
            top > base && top <= (base + len),
            "fiber stack top pointer not in range"
        );

        // Remove the guard page from the size
        let stack_size = self.stack_size.byte_count() - self.page_size.byte_count();
        let bottom_of_stack = top - stack_size;
        let start_of_stack = bottom_of_stack - self.page_size.byte_count();
        assert!(start_of_stack >= base && start_of_stack < (base + len));
        assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);

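        // Recover the slot index from this stack's offset within the pool
        // mapping.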
        let index = (start_of_stack - base) / self.stack_size.byte_count();
        assert!(index < self.max_stacks);
        let index = u32::try_from(index).unwrap();

        self.index_allocator.free(SlotId(index), bytes_resident);
    }

    pub fn unused_warm_slots(&self) -> u32 {
        self.index_allocator.unused_warm_slots()
    }

    pub fn unused_bytes_resident(&self) -> Option<usize> {
        if self.async_stack_zeroing {
            Some(self.index_allocator.unused_bytes_resident())
        } else {
            None
        }
    }
}

#[cfg(all(test, unix, feature = "async", not(miri), not(asan)))]
mod tests {
    use super::*;
    use crate::runtime::vm::InstanceLimits;

    #[test]
    fn test_stack_pool() -> Result<()> {
        let config = PoolingInstanceAllocatorConfig {
            limits: InstanceLimits {
                total_stacks: 10,
                ..Default::default()
            },
            stack_size: 1,
            async_stack_zeroing: true,
            ..PoolingInstanceAllocatorConfig::default()
        };
        let pool = StackPool::new(&config)?;

        let native_page_size = crate::runtime::vm::host_page_size();
        assert_eq!(pool.stack_size, 2 * native_page_size);
        assert_eq!(pool.max_stacks, 10);
        assert_eq!(pool.page_size, native_page_size);

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        let base = pool.mapping.as_ptr() as usize;

        let mut stacks = Vec::new();
        for i in 0..10 {
            let stack = pool.allocate().expect("allocation should succeed");
            assert_eq!(
                ((stack.top().unwrap() as usize - base) / pool.stack_size.byte_count()) - 1,
                i
            );
            stacks.push(stack);
        }

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        assert!(pool.allocate().is_err(), "allocation should fail");

        for stack in stacks {
            unsafe {
                pool.deallocate(stack, 0);
            }
        }

        assert_eq!(
            pool.index_allocator.testing_freelist(),
            [
                SlotId(0),
                SlotId(1),
                SlotId(2),
                SlotId(3),
                SlotId(4),
                SlotId(5),
                SlotId(6),
                SlotId(7),
                SlotId(8),
                SlotId(9)
            ],
        );

        Ok(())
    }
}