Skip to main content

wasmtime_wizer/
snapshot.rs

1use crate::InstanceState;
2use crate::info::ModuleContext;
3#[cfg(not(feature = "rayon"))]
4use crate::rayoff::{IntoParallelIterator, ParallelExtend};
5#[cfg(feature = "rayon")]
6use rayon::iter::{IntoParallelIterator, ParallelExtend, ParallelIterator};
7use std::convert::TryFrom;
8use std::ops::Range;
9
/// Upper bound on how many data segments a snapshot will ever emit.
///
/// Engines generally accept more segments than this; the lower cap is chosen
/// deliberately to stay comfortably under their limits.
const MAX_DATA_SEGMENTS: usize = 10_000;
14
/// A "snapshot" of Wasm state from its default value after having been initialized.
#[derive(Debug)]
pub struct Snapshot {
    /// Maps global index to its initialized value.
    ///
    /// Note that this only tracks defined mutable globals, not all globals.
    pub globals: Vec<(u32, SnapshotVal)>,

    /// A new minimum size for each memory (in units of pages).
    pub memory_mins: Vec<u64>,

    /// Segments of non-zero memory.
    pub data_segments: Vec<DataSegment>,
}

/// A value from a snapshot, currently a subset of wasm types that aren't
/// reference types.
///
/// The `F32`, `F64` and `V128` variants carry their payloads as unsigned
/// integers of matching width rather than as Rust floating-point values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[expect(missing_docs, reason = "self-describing variants")]
pub enum SnapshotVal {
    I32(i32),
    I64(i64),
    F32(u32),
    F64(u64),
    V128(u128),
}

/// A data segment initializer for a memory.
#[derive(Clone, Debug)]
pub struct DataSegment {
    /// The index of this data segment's memory.
    pub memory_index: u32,

    /// This data segment's initialized memory that it originated from.
    pub data: Vec<u8>,

    /// The offset within the memory that `data` should be copied to.
    pub offset: u64,

    /// Whether or not `memory_index` is a 64-bit memory.
    pub is64: bool,
}
55
56/// Snapshot the given instance's globals, memories, and instances from the Wasm
57/// defaults.
58//
59// TODO: when we support reference types, we will have to snapshot tables.
60pub async fn snapshot(module: &ModuleContext<'_>, ctx: &mut impl InstanceState) -> Snapshot {
61    log::debug!("Snapshotting the initialized state");
62
63    let globals = snapshot_globals(module, ctx).await;
64    let (memory_mins, data_segments) = snapshot_memories(module, ctx).await;
65
66    Snapshot {
67        globals,
68        memory_mins,
69        data_segments,
70    }
71}
72
73/// Get the initialized values of all globals.
74async fn snapshot_globals(
75    module: &ModuleContext<'_>,
76    ctx: &mut impl InstanceState,
77) -> Vec<(u32, SnapshotVal)> {
78    log::debug!("Snapshotting global values");
79
80    let mut ret = Vec::new();
81    for (i, ty, name) in module.defined_globals() {
82        if let Some(name) = name {
83            let val = ctx.global_get(name, ty.content_type).await;
84            ret.push((i, val));
85        }
86    }
87    ret
88}
89
/// A byte range within one memory that will later become a data segment once
/// its contents are copied out.
#[derive(Clone)]
struct DataSegmentRange {
    /// Which memory `range` indexes into.
    memory_index: u32,
    /// Half-open byte range within that memory.
    range: Range<usize>,
}

impl DataSegmentRange {
    /// Count the bytes lying between the end of `self` and the start of
    /// `other`.
    ///
    /// Both ranges must belong to the same memory, and `self` must end at or
    /// before the point where `other` starts (checked only in debug builds).
    fn gap(&self, other: &Self) -> usize {
        debug_assert_eq!(self.memory_index, other.memory_index);
        debug_assert!(self.range.end <= other.range.start);
        let (earlier_end, later_start) = (self.range.end, other.range.start);
        later_start - earlier_end
    }

    /// Extend `self` so it also covers `other`, absorbing any gap between
    /// them.
    ///
    /// Same preconditions as [`DataSegmentRange::gap`].
    fn merge(&mut self, other: &Self) {
        debug_assert_eq!(self.memory_index, other.memory_index);
        debug_assert!(self.range.end <= other.range.start);
        let new_end = other.range.end;
        self.range.end = new_end;
    }
}
117
118/// Find the initialized minimum page size of each memory, as well as all
119/// regions of non-zero memory.
120async fn snapshot_memories(
121    module: &ModuleContext<'_>,
122    instance: &mut impl InstanceState,
123) -> (Vec<u64>, Vec<DataSegment>) {
124    log::debug!("Snapshotting memories");
125
126    // Find and record non-zero regions of memory (in parallel).
127    let mut memory_mins = vec![];
128    let mut data_segments = vec![];
129    let iter = module
130        .defined_memories()
131        .zip(module.defined_memory_exports.as_ref().unwrap());
132    for ((memory_index, ty), name) in iter {
133        instance
134            .memory_contents(&name, |memory| {
135                let page_size = 1 << ty.page_size_log2.unwrap_or(16);
136                let num_wasm_pages = memory.len() / page_size;
137                memory_mins.push(num_wasm_pages as u64);
138
139                let memory_data = &memory[..];
140
141                // Consider each Wasm page in parallel. Create data segments for each
142                // region of non-zero memory.
143                data_segments.par_extend((0..num_wasm_pages).into_par_iter().flat_map(|i| {
144                    let page_end = (i + 1) * page_size;
145                    let mut start = i * page_size;
146                    let mut segments = vec![];
147                    while start < page_end {
148                        let nonzero = match memory_data[start..page_end]
149                            .iter()
150                            .position(|byte| *byte != 0)
151                        {
152                            None => break,
153                            Some(i) => i,
154                        };
155                        start += nonzero;
156                        let end = memory_data[start..page_end]
157                            .iter()
158                            .position(|byte| *byte == 0)
159                            .map_or(page_end, |zero| start + zero);
160                        segments.push(DataSegmentRange {
161                            memory_index,
162                            range: start..end,
163                        });
164                        start = end;
165                    }
166                    segments
167                }));
168            })
169            .await;
170    }
171
172    if data_segments.is_empty() {
173        return (memory_mins, Vec::new());
174    }
175
176    // Sort data segments to enforce determinism in the face of the
177    // parallelism above.
178    data_segments.sort_by_key(|s| (s.memory_index, s.range.start));
179
180    // Merge any contiguous segments (caused by spanning a Wasm page boundary,
181    // and therefore created in separate logical threads above) or pages that
182    // are within four bytes of each other. Four because this is the minimum
183    // overhead of defining a new active data segment: one for the memory index
184    // LEB, two for the memory offset init expression (one for the `i32.const`
185    // opcode and another for the constant immediate LEB), and finally one for
186    // the data length LEB).
187    const MIN_ACTIVE_SEGMENT_OVERHEAD: usize = 4;
188    let mut merged_data_segments = Vec::with_capacity(data_segments.len());
189    merged_data_segments.push(data_segments[0].clone());
190    for b in &data_segments[1..] {
191        let a = merged_data_segments.last_mut().unwrap();
192
193        // Only merge segments for the same memory.
194        if a.memory_index != b.memory_index {
195            merged_data_segments.push(b.clone());
196            continue;
197        }
198
199        // Only merge segments if they are contiguous or if it is definitely
200        // more size efficient than leaving them apart.
201        let gap = a.gap(b);
202        if gap > MIN_ACTIVE_SEGMENT_OVERHEAD {
203            merged_data_segments.push(b.clone());
204            continue;
205        }
206
207        // Okay, merge them together into `a` (so that the next iteration can
208        // merge it with its predecessor) and then omit `b`!
209        a.merge(b);
210    }
211
212    remove_excess_segments(&mut merged_data_segments);
213
214    // With the final set of data segments now extract the actual data of each
215    // memory, copying it into a `DataSegment`, to return the final list of
216    // segments.
217    //
218    // Here the memories are iterated over again and, in tandem, the
219    // `merged_data_segments` list is traversed to extract a `DataSegment` for
220    // each range that `merged_data_segments` indicates. This relies on
221    // `merged_data_segments` being a sorted list by `memory_index` at least.
222    let mut final_data_segments = Vec::with_capacity(merged_data_segments.len());
223    let mut merged = merged_data_segments.iter().peekable();
224    let iter = module
225        .defined_memories()
226        .zip(module.defined_memory_exports.as_ref().unwrap());
227    for ((memory_index, ty), name) in iter {
228        instance
229            .memory_contents(&name, |memory| {
230                while let Some(segment) = merged.next_if(|s| s.memory_index == memory_index) {
231                    final_data_segments.push(DataSegment {
232                        memory_index,
233                        data: memory[segment.range.clone()].to_vec(),
234                        offset: segment.range.start.try_into().unwrap(),
235                        is64: ty.memory64,
236                    });
237                }
238            })
239            .await;
240    }
241    assert!(merged.next().is_none());
242
243    (memory_mins, final_data_segments)
244}
245
246/// Engines apply a limit on how many segments a module may contain, and Wizer
247/// can run afoul of it. When that happens, we need to merge data segments
248/// together until our number of data segments fits within the limit.
249fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegmentRange>) {
250    if merged_data_segments.len() < MAX_DATA_SEGMENTS {
251        return;
252    }
253
254    // We need to remove `excess` number of data segments.
255    let excess = merged_data_segments.len() - MAX_DATA_SEGMENTS;
256
257    #[derive(Clone, Copy, PartialEq, Eq)]
258    struct GapIndex {
259        gap: u32,
260        // Use a `u32` instead of `usize` to fit `GapIndex` within a word on
261        // 64-bit systems, using less memory.
262        index: u32,
263    }
264
265    // Find the gaps between the start of one segment and the next (if they are
266    // both in the same memory). We will merge the `excess` segments with the
267    // smallest gaps together. Because they are the smallest gaps, this will
268    // bloat the size of our data segment the least.
269    let mut smallest_gaps = Vec::with_capacity(merged_data_segments.len() - 1);
270    for (index, w) in merged_data_segments.windows(2).enumerate() {
271        if w[0].memory_index != w[1].memory_index {
272            continue;
273        }
274        let gap = match u32::try_from(w[0].gap(&w[1])) {
275            Ok(gap) => gap,
276            // If the gap is larger than 4G then don't consider these two data
277            // segments for merging and assume there's enough other data
278            // segments close enough together to still consider for merging to
279            // get under the limit.
280            Err(_) => continue,
281        };
282        let index = u32::try_from(index).unwrap();
283        smallest_gaps.push(GapIndex { gap, index });
284    }
285    smallest_gaps.sort_unstable_by_key(|g| g.gap);
286    smallest_gaps.truncate(excess);
287
288    // Now merge the chosen segments together in reverse index order so that
289    // merging two segments doesn't mess up the index of the next segments we
290    // will to merge.
291    smallest_gaps.sort_unstable_by(|a, b| a.index.cmp(&b.index).reverse());
292    for GapIndex { index, .. } in smallest_gaps {
293        let index = usize::try_from(index).unwrap();
294        let [a, b] = merged_data_segments
295            .get_disjoint_mut([index, index + 1])
296            .unwrap();
297        a.merge(b);
298
299        // Okay to use `swap_remove` here because, even though it makes
300        // `merged_data_segments` unsorted, the segments are still sorted within
301        // the range `0..index` and future iterations will only operate within
302        // that subregion because we are iterating over largest to smallest
303        // indices.
304        merged_data_segments.swap_remove(index + 1);
305    }
306
307    // Finally, sort the data segments again so that our output is
308    // deterministic.
309    merged_data_segments.sort_by_key(|s| (s.memory_index, s.range.start));
310}