wasmtime_environ/
string_pool.rs1use crate::{
4 collections::{HashMap, String, Vec},
5 error::OutOfMemory,
6 prelude::*,
7};
8use core::{fmt, mem, num::NonZeroU32};
9use wasmtime_core::alloc::TryClone;
10
/// A compact handle for a string interned in a [`StringPool`].
///
/// The handle stores `index + 1` in a `NonZeroU32`. This keeps it four bytes
/// wide and lets `Option<Atom>` use the zero niche.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Atom {
    /// One more than the string's zero-based position in the pool.
    index: NonZeroU32,
}
19
20#[derive(Default)]
33pub struct StringPool {
34 map: mem::ManuallyDrop<HashMap<&'static str, Atom>>,
37
38 strings: mem::ManuallyDrop<Vec<Box<str>>>,
41}
42
43impl Drop for StringPool {
44 fn drop(&mut self) {
45 unsafe {
50 mem::ManuallyDrop::drop(&mut self.map);
51 mem::ManuallyDrop::drop(&mut self.strings);
52 }
53 }
54}
55
56impl fmt::Debug for StringPool {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 struct Strings<'a>(&'a StringPool);
59 impl fmt::Debug for Strings<'_> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_map()
62 .entries(
63 self.0
64 .strings
65 .iter()
66 .enumerate()
67 .map(|(i, s)| (Atom::new(i), s)),
68 )
69 .finish()
70 }
71 }
72
73 f.debug_struct("StringPool")
74 .field("strings", &Strings(self))
75 .finish()
76 }
77}
78
79impl TryClone for StringPool {
80 fn try_clone(&self) -> Result<Self, OutOfMemory> {
81 Ok(StringPool {
82 map: self.map.try_clone()?,
83 strings: self.strings.try_clone()?,
84 })
85 }
86}
87
88impl TryClone for Atom {
89 fn try_clone(&self) -> Result<Self, OutOfMemory> {
90 Ok(*self)
91 }
92}
93
94impl core::ops::Index<Atom> for StringPool {
95 type Output = str;
96
97 #[inline]
98 #[track_caller]
99 fn index(&self, atom: Atom) -> &Self::Output {
100 self.get(atom).unwrap()
101 }
102}
103
104impl core::ops::Index<&'_ Atom> for StringPool {
106 type Output = str;
107
108 #[inline]
109 #[track_caller]
110 fn index(&self, atom: &Atom) -> &Self::Output {
111 self.get(*atom).unwrap()
112 }
113}
114
115impl serde::ser::Serialize for StringPool {
116 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117 where
118 S: serde::Serializer,
119 {
120 serde::ser::Serialize::serialize(&*self.strings, serializer)
121 }
122}
123
124impl<'de> serde::de::Deserialize<'de> for StringPool {
125 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126 where
127 D: serde::Deserializer<'de>,
128 {
129 struct Visitor;
130 impl<'de> serde::de::Visitor<'de> for Visitor {
131 type Value = StringPool;
132
133 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
134 f.write_str("a `StringPool` sequence of strings")
135 }
136
137 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
138 where
139 A: serde::de::SeqAccess<'de>,
140 {
141 use serde::de::Error as _;
142
143 let mut pool = StringPool::new();
144
145 if let Some(len) = seq.size_hint() {
146 pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
147 pool.strings
148 .reserve(len)
149 .map_err(|oom| A::Error::custom(oom))?;
150 }
151
152 while let Some(s) = seq.next_element::<String>()? {
153 debug_assert_eq!(s.len(), s.capacity());
154 let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
155 if !pool.map.contains_key(&*s) {
156 pool.insert_new_boxed_str(s)
157 .map_err(|oom| A::Error::custom(oom))?;
158 }
159 }
160
161 Ok(pool)
162 }
163 }
164 deserializer.deserialize_seq(Visitor)
165 }
166}
167
168impl StringPool {
169 pub fn new() -> Self {
171 Self::default()
172 }
173
174 pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
176 if let Some(atom) = self.map.get(s) {
177 return Ok(*atom);
178 }
179
180 self.map.reserve(1)?;
181 self.strings.reserve(1)?;
182
183 let mut owned = String::new();
184 owned.reserve_exact(s.len())?;
185 owned.push_str(s).expect("reserved capacity");
186 let owned = owned
187 .into_boxed_str()
188 .expect("reserved exact capacity, so shouldn't need to realloc");
189
190 self.insert_new_boxed_str(owned)
191 }
192
193 fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
194 debug_assert!(!self.map.contains_key(&*owned));
195
196 let index = self.strings.len();
197 let atom = Atom::new(index);
198 self.strings.push(owned)?;
199
200 let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
203
204 let old = self.map.insert(s, atom)?;
205 debug_assert!(old.is_none());
206
207 Ok(atom)
208 }
209
210 pub fn get_atom(&self, s: &str) -> Option<Atom> {
213 self.map.get(s).copied()
214 }
215
216 #[inline]
218 pub fn contains(&self, atom: Atom) -> bool {
219 atom.index() < self.strings.len()
220 }
221
222 #[inline]
225 pub fn get(&self, atom: Atom) -> Option<&str> {
226 if self.contains(atom) {
227 Some(&self.strings[atom.index()])
228 } else {
229 None
230 }
231 }
232
233 pub fn len(&self) -> usize {
235 self.strings.len()
236 }
237}
238
239impl Default for Atom {
240 #[inline]
241 fn default() -> Self {
242 Self {
243 index: NonZeroU32::MAX,
244 }
245 }
246}
247
248impl fmt::Debug for Atom {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 f.debug_struct("Atom")
251 .field("index", &self.index())
252 .finish()
253 }
254}
255
256impl crate::EntityRef for Atom {
258 fn new(index: usize) -> Self {
259 Atom::new(index)
260 }
261
262 fn index(self) -> usize {
263 Atom::index(&self)
264 }
265}
266
267impl serde::ser::Serialize for Atom {
268 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
269 where
270 S: serde::Serializer,
271 {
272 serde::ser::Serialize::serialize(&self.index, serializer)
273 }
274}
275
276impl<'de> serde::de::Deserialize<'de> for Atom {
277 fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
278 where
279 D: serde::Deserializer<'de>,
280 {
281 let index = serde::de::Deserialize::deserialize(deserializer)?;
282 Ok(Self { index })
283 }
284}
285
286impl Atom {
287 fn new(index: usize) -> Self {
288 assert!(index < usize::try_from(u32::MAX).unwrap());
289 let index = u32::try_from(index).unwrap();
290 let index = NonZeroU32::new(index + 1).unwrap();
291 Self { index }
292 }
293
294 pub fn index(&self) -> usize {
296 let index = self.index.get() - 1;
297 usize::try_from(index).unwrap()
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() -> Result<()> {
        let mut pool = StringPool::new();

        // A freshly interned string resolves back to itself.
        let a = pool.insert("a")?;
        assert_eq!(&pool[a], "a");
        assert_eq!(pool.get_atom("a"), Some(a));

        // Interning the same string again yields the same atom.
        let a_again = pool.insert("a")?;
        assert_eq!(a, a_again);
        assert_eq!(&pool[a_again], "a");

        // A different string gets a different atom.
        let b = pool.insert("b")?;
        assert_eq!(&pool[b], "b");
        assert_ne!(a, b);
        assert_eq!(pool.get_atom("b"), Some(b));

        // Strings that were never interned have no atom.
        assert!(pool.get_atom("zzz").is_none());

        // Atoms are plain indices: the first atom of any pool compares equal,
        // and an index past a pool's length is simply not contained in it.
        let mut other = StringPool::new();
        let c = other.insert("c")?;
        assert_eq!(&other[c], "c");
        assert_eq!(a, c);
        assert_eq!(&other[a], "c");
        assert!(!other.contains(b));
        assert!(other.get(b).is_none());

        Ok(())
    }

    #[test]
    fn stress() -> Result<()> {
        let n = if cfg!(miri) { 100 } else { 10_000 };
        let mut pool = StringPool::new();

        // String `i` is interned as atom `i`. The second round re-inserts
        // strings that already exist.
        for _round in 0..2 {
            let atoms: Vec<_> = (0..n).map(|i| pool.insert(&i.to_string())).try_collect()?;
            for atom in atoms {
                assert!(pool.contains(atom));
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut pool = StringPool::new();
        let a = pool.insert("a")?;
        let b = pool.insert("b")?;
        let c = pool.insert("c")?;

        let bytes = postcard::to_allocvec(&(pool, a, b, c))?;
        let (pool, a2, b2, c2) = postcard::from_bytes::<(StringPool, Atom, Atom, Atom)>(&bytes)?;

        // Atoms minted before serialization still resolve in the decoded pool...
        for (atom, expected) in [(a, "a"), (b, "b"), (c, "c")] {
            assert_eq!(&pool[atom], expected);
        }

        // ...and so do the atoms decoded alongside it.
        for (atom, expected) in [(a2, "a"), (b2, "b"), (c2, "c")] {
            assert_eq!(&pool[atom], expected);
        }

        Ok(())
    }
}