Skip to main content

wasmtime_environ/
string_pool.rs

1//! Simple string interning.
2
3use crate::{
4    collections::{HashMap, String, Vec},
5    error::OutOfMemory,
6    prelude::*,
7};
8use core::{fmt, mem, num::NonZeroU32};
9use wasmtime_core::alloc::TryClone;
10
/// An interned string associated with a particular string in a `StringPool`.
///
/// Allows for $O(1)$ equality tests, $O(1)$ hashing, and $O(1)$
/// arbitrary-but-stable ordering.
///
/// All of those operations act on the stored integer only (they are
/// derived); the string contents are never consulted.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Atom {
    /// The atom's zero-based position in its pool, stored plus one so that
    /// the value is never zero (see `Atom::new` and `Atom::index`).
    index: NonZeroU32,
}
19
/// A pool of interned strings.
///
/// Insert new strings with [`StringPool::insert`] to get an `Atom` that is
/// unique per string within the context of the associated pool.
///
/// Once you have interned a string into the pool and have its `Atom`, you can
/// get the interned string slice via `&pool[atom]` or `pool.get(atom)`.
///
/// In general, there are no correctness protections against indexing into a
/// different `StringPool` from the one that the `Atom` was allocated inside.
/// Doing so is memory safe but may panic or otherwise return incorrect
/// results.
#[derive(Default)]
pub struct StringPool {
    /// A map from each string in this pool (as an unsafe borrow from
    /// `self.strings`) to its `Atom`.
    ///
    /// The `'static` lifetime on the keys is not real: each key actually
    /// points into a `Box<str>` owned by `self.strings`, so the keys are only
    /// valid for as long as that box is alive. The map is wrapped in
    /// `ManuallyDrop` so that the `Drop` impl can drop it before
    /// `self.strings`.
    map: mem::ManuallyDrop<HashMap<&'static str, Atom>>,

    /// Strings in this pool. These must never be mutated or reallocated once
    /// inserted.
    ///
    /// The string for an atom is stored at position `atom.index()`.
    strings: mem::ManuallyDrop<Vec<Box<str>>>,
}
42
impl Drop for StringPool {
    /// Drops the pool's two fields by hand, in a fixed order.
    fn drop(&mut self) {
        // Ensure that `self.map` is dropped before `self.strings`, since
        // `self.map` borrows from `self.strings`.
        //
        // Both fields are wrapped in `ManuallyDrop`, so they are not dropped
        // automatically after this function returns.
        //
        // SAFETY: Neither field will be used again.
        unsafe {
            mem::ManuallyDrop::drop(&mut self.map);
            mem::ManuallyDrop::drop(&mut self.strings);
        }
    }
}
55
56impl fmt::Debug for StringPool {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        struct Strings<'a>(&'a StringPool);
59        impl fmt::Debug for Strings<'_> {
60            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61                f.debug_map()
62                    .entries(
63                        self.0
64                            .strings
65                            .iter()
66                            .enumerate()
67                            .map(|(i, s)| (Atom::new(i), s)),
68                    )
69                    .finish()
70            }
71        }
72
73        f.debug_struct("StringPool")
74            .field("strings", &Strings(self))
75            .finish()
76    }
77}
78
79impl TryClone for StringPool {
80    fn try_clone(&self) -> Result<Self, OutOfMemory> {
81        Ok(StringPool {
82            map: self.map.try_clone()?,
83            strings: self.strings.try_clone()?,
84        })
85    }
86}
87
88impl TryClone for Atom {
89    fn try_clone(&self) -> Result<Self, OutOfMemory> {
90        Ok(*self)
91    }
92}
93
94impl core::ops::Index<Atom> for StringPool {
95    type Output = str;
96
97    #[inline]
98    #[track_caller]
99    fn index(&self, atom: Atom) -> &Self::Output {
100        self.get(atom).unwrap()
101    }
102}
103
104// For convenience, to avoid `*atom` noise at call sites.
105impl core::ops::Index<&'_ Atom> for StringPool {
106    type Output = str;
107
108    #[inline]
109    #[track_caller]
110    fn index(&self, atom: &Atom) -> &Self::Output {
111        self.get(*atom).unwrap()
112    }
113}
114
115impl serde::ser::Serialize for StringPool {
116    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117    where
118        S: serde::Serializer,
119    {
120        serde::ser::Serialize::serialize(&*self.strings, serializer)
121    }
122}
123
124impl<'de> serde::de::Deserialize<'de> for StringPool {
125    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126    where
127        D: serde::Deserializer<'de>,
128    {
129        struct Visitor;
130        impl<'de> serde::de::Visitor<'de> for Visitor {
131            type Value = StringPool;
132
133            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
134                f.write_str("a `StringPool` sequence of strings")
135            }
136
137            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
138            where
139                A: serde::de::SeqAccess<'de>,
140            {
141                use serde::de::Error as _;
142
143                let mut pool = StringPool::new();
144
145                if let Some(len) = seq.size_hint() {
146                    pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
147                    pool.strings
148                        .reserve(len)
149                        .map_err(|oom| A::Error::custom(oom))?;
150                }
151
152                while let Some(s) = seq.next_element::<String>()? {
153                    debug_assert_eq!(s.len(), s.capacity());
154                    let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
155                    if !pool.map.contains_key(&*s) {
156                        pool.insert_new_boxed_str(s)
157                            .map_err(|oom| A::Error::custom(oom))?;
158                    }
159                }
160
161                Ok(pool)
162            }
163        }
164        deserializer.deserialize_seq(Visitor)
165    }
166}
167
168impl StringPool {
169    /// Create a new, empty pool.
170    pub fn new() -> Self {
171        Self::default()
172    }
173
174    /// Insert a new string into this pool.
175    pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
176        if let Some(atom) = self.map.get(s) {
177            return Ok(*atom);
178        }
179
180        self.map.reserve(1)?;
181        self.strings.reserve(1)?;
182
183        let mut owned = String::new();
184        owned.reserve_exact(s.len())?;
185        owned.push_str(s).expect("reserved capacity");
186        let owned = owned
187            .into_boxed_str()
188            .expect("reserved exact capacity, so shouldn't need to realloc");
189
190        self.insert_new_boxed_str(owned)
191    }
192
193    fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
194        debug_assert!(!self.map.contains_key(&*owned));
195
196        let index = self.strings.len();
197        let atom = Atom::new(index);
198        self.strings.push(owned)?;
199
200        // SAFETY: We never expose this borrow and never mutate or reallocate
201        // strings once inserted into the pool.
202        let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
203
204        let old = self.map.insert(s, atom)?;
205        debug_assert!(old.is_none());
206
207        Ok(atom)
208    }
209
210    /// Get the `Atom` for the given string, if it has already been inserted
211    /// into this pool.
212    pub fn get_atom(&self, s: &str) -> Option<Atom> {
213        self.map.get(s).copied()
214    }
215
216    /// Does this pool contain the given `atom`?
217    #[inline]
218    pub fn contains(&self, atom: Atom) -> bool {
219        atom.index() < self.strings.len()
220    }
221
222    /// Get the string associated with the given `atom`, if the pool contains
223    /// the atom.
224    #[inline]
225    pub fn get(&self, atom: Atom) -> Option<&str> {
226        if self.contains(atom) {
227            Some(&self.strings[atom.index()])
228        } else {
229            None
230        }
231    }
232
233    /// Get the number of strings in this pool.
234    pub fn len(&self) -> usize {
235        self.strings.len()
236    }
237}
238
239impl Default for Atom {
240    #[inline]
241    fn default() -> Self {
242        Self {
243            index: NonZeroU32::MAX,
244        }
245    }
246}
247
248impl fmt::Debug for Atom {
249    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250        f.debug_struct("Atom")
251            .field("index", &self.index())
252            .finish()
253    }
254}
255
256// Allow using `Atom` in `SecondaryMap`s.
257impl crate::EntityRef for Atom {
258    fn new(index: usize) -> Self {
259        Atom::new(index)
260    }
261
262    fn index(self) -> usize {
263        Atom::index(&self)
264    }
265}
266
267impl serde::ser::Serialize for Atom {
268    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
269    where
270        S: serde::Serializer,
271    {
272        serde::ser::Serialize::serialize(&self.index, serializer)
273    }
274}
275
276impl<'de> serde::de::Deserialize<'de> for Atom {
277    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
278    where
279        D: serde::Deserializer<'de>,
280    {
281        let index = serde::de::Deserialize::deserialize(deserializer)?;
282        Ok(Self { index })
283    }
284}
285
286impl Atom {
287    fn new(index: usize) -> Self {
288        assert!(index < usize::try_from(u32::MAX).unwrap());
289        let index = u32::try_from(index).unwrap();
290        let index = NonZeroU32::new(index + 1).unwrap();
291        Self { index }
292    }
293
294    /// Get this atom's index in its pool.
295    pub fn index(&self) -> usize {
296        let index = self.index.get() - 1;
297        usize::try_from(index).unwrap()
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insertion, de-duplication, lookup by string and by atom, and
    /// what happens when an atom is used with a pool it did not come from.
    #[test]
    fn basic() -> Result<()> {
        let mut pool = StringPool::new();

        let a = pool.insert("a")?;
        assert_eq!(&pool[a], "a");
        assert_eq!(pool.get_atom("a"), Some(a));

        // Inserting an equal string again yields the same atom.
        let a2 = pool.insert("a")?;
        assert_eq!(a, a2);
        assert_eq!(&pool[a2], "a");

        let b = pool.insert("b")?;
        assert_eq!(&pool[b], "b");
        assert_ne!(a, b);
        assert_eq!(pool.get_atom("b"), Some(b));

        assert!(pool.get_atom("zzz").is_none());

        // Atoms are plain indices: the first atom of a second pool compares
        // equal to the first atom of the first pool, and indexing the second
        // pool with it returns the second pool's string.
        let mut pool2 = StringPool::new();
        let c = pool2.insert("c")?;
        assert_eq!(&pool2[c], "c");
        assert_eq!(a, c);
        assert_eq!(&pool2[a], "c");
        // `b` has index 1, which is out of bounds for the one-element pool.
        assert!(!pool2.contains(b));
        assert!(pool2.get(b).is_none());

        Ok(())
    }

    /// Inserts many distinct strings, twice, and checks every returned atom.
    #[test]
    fn stress() -> Result<()> {
        let mut pool = StringPool::new();

        // Use a smaller workload when running under Miri.
        let n = if cfg!(miri) { 100 } else { 10_000 };

        // The second pass re-inserts strings that are already interned.
        for _ in 0..2 {
            let atoms: Vec<_> = (0..n).map(|i| pool.insert(&i.to_string())).try_collect()?;

            for atom in atoms {
                assert!(pool.contains(atom));
                // The string for `i` was the `i`th insertion, so each atom's
                // index spells out its own string.
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    /// Serializes a pool together with some of its atoms and checks that both
    /// the original and the deserialized atoms resolve in the deserialized
    /// pool.
    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut pool = StringPool::new();
        let a = pool.insert("a")?;
        let b = pool.insert("b")?;
        let c = pool.insert("c")?;

        let bytes = postcard::to_allocvec(&(pool, a, b, c))?;
        let (pool, a2, b2, c2) = postcard::from_bytes::<(StringPool, Atom, Atom, Atom)>(&bytes)?;

        // Atoms created before serialization still work with the new pool.
        assert_eq!(&pool[a], "a");
        assert_eq!(&pool[b], "b");
        assert_eq!(&pool[c], "c");

        // So do the atoms that went through serialization themselves.
        assert_eq!(&pool[a2], "a");
        assert_eq!(&pool[b2], "b");
        assert_eq!(&pool[c2], "c");

        Ok(())
    }
}