1use http::header::Entry;
2use http::{HeaderMap, HeaderName, HeaderValue};
3use std::fmt;
4use std::ops::Deref;
5use std::sync::Arc;
6use wasmtime::Result;
7
8#[derive(Debug, Clone)]
24pub struct FieldMap {
25 map: Arc<HeaderMap>,
26 limit: Limit,
27 size: usize,
28}
29
/// Mutability state of a field map.
#[derive(Clone, Debug)]
enum Limit {
    /// Mutation is permitted; operations that are checked against the limit
    /// fail if the resulting accounted size would exceed this value.
    Mutable(usize),
    /// Every mutating operation is rejected.
    Immutable,
}
35
36impl Default for FieldMap {
37 fn default() -> Self {
38 Self::new_immutable(HeaderMap::default())
39 }
40}
41
42impl FieldMap {
43 pub fn new_immutable(map: HeaderMap) -> Self {
49 let size = Self::content_size(&map);
50 Self {
51 map: Arc::new(map),
52 size,
53 limit: Limit::Immutable,
54 }
55 }
56
57 pub fn new_mutable(limit: usize) -> Self {
62 Self {
63 map: Arc::new(HeaderMap::new()),
64 size: 0,
65 limit: Limit::Mutable(limit),
66 }
67 }
68
69 pub(crate) fn content_size(map: &HeaderMap) -> usize {
72 let mut sum = 0;
73 for key in map.keys() {
74 sum += header_name_size(key);
75 }
76 for value in map.values() {
77 sum += header_value_size(value);
78 }
79 sum
80 }
81
82 pub fn set(&mut self, key: HeaderName, values: Vec<HeaderValue>) -> Result<(), FieldMapError> {
90 let (map, limit, size) = self.mutable()?;
91 let key_size = header_name_size(&key);
92 let values_size = values.iter().map(header_value_size).sum::<usize>();
93 let mut values = values.into_iter();
94 let mut entry = match map.try_entry(key)? {
95 Entry::Vacant(e) => match values.next() {
96 Some(v) => {
97 update_size(size, limit, *size + values_size + key_size)?;
98 e.try_insert_entry(v)?
99 }
100 None => return Ok(()),
101 },
102 Entry::Occupied(mut e) => {
103 let prev_values_size = e.iter().map(header_value_size).sum::<usize>();
104 let _prev = match values.next() {
105 Some(v) => {
106 update_size(size, limit, *size - prev_values_size + values_size)?;
107 e.insert(v);
108 }
109 None => {
110 update_size(size, limit, *size - prev_values_size - key_size)?;
111 e.remove();
112 return Ok(());
113 }
114 };
115 e
116 }
117 };
118 for value in values {
119 entry.append(value);
120 }
121 Ok(())
122 }
123
124 pub fn remove_all(&mut self, key: HeaderName) -> Result<Vec<HeaderValue>, FieldMapError> {
128 let (map, _limit, size) = self.mutable()?;
129 match map.try_entry(key)? {
130 Entry::Vacant { .. } => Ok(Vec::new()),
131 Entry::Occupied(e) => {
132 let (name, value_drain) = e.remove_entry_mult();
133 let mut removed = header_name_size(&name);
134 let values = value_drain.collect::<Vec<_>>();
135 for v in values.iter() {
136 removed += header_value_size(v);
137 }
138 *size -= removed;
139 Ok(values)
140 }
141 }
142 }
143
144 fn mutable(&mut self) -> Result<(&mut HeaderMap, usize, &mut usize), FieldMapError> {
145 match self.limit {
146 Limit::Immutable => Err(FieldMapError::Immutable),
147 Limit::Mutable(limit) => Ok((Arc::make_mut(&mut self.map), limit, &mut self.size)),
148 }
149 }
150
151 pub fn append(&mut self, key: HeaderName, value: HeaderValue) -> Result<bool, FieldMapError> {
156 let (map, limit, size) = self.mutable()?;
157 let key_size = header_name_size(&key);
158 let val_size = header_value_size(&value);
159 let new_size = if !map.contains_key(&key) {
160 *size + key_size + val_size
161 } else {
162 *size + val_size
163 };
164 update_size(size, limit, new_size)?;
165 let already_present = map.try_append(key, value)?;
166 self.size = new_size;
167 Ok(already_present)
168 }
169
170 pub fn set_mutable(&mut self, limit: usize) {
173 self.limit = Limit::Mutable(limit);
174 }
175
176 pub fn set_immutable(&mut self) {
178 self.limit = Limit::Immutable;
179 }
180}
181
182fn header_name_size(name: &HeaderName) -> usize {
187 name.as_str().len() + size_of::<HeaderName>()
188}
189
190fn header_value_size(value: &HeaderValue) -> usize {
196 value.len() + size_of::<HeaderValue>()
197}
198
199fn update_size(size: &mut usize, limit: usize, new: usize) -> Result<(), FieldMapError> {
200 if new > limit {
201 Err(FieldMapError::TotalSizeTooBig)
202 } else {
203 *size = new;
204 Ok(())
205 }
206}
207
208impl Deref for FieldMap {
211 type Target = HeaderMap;
212
213 fn deref(&self) -> &HeaderMap {
214 &self.map
215 }
216}
217
218impl From<FieldMap> for HeaderMap {
219 fn from(map: FieldMap) -> Self {
220 Arc::unwrap_or_clone(map.map)
221 }
222}
223
/// Errors returned by the mutating operations of a field map.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FieldMapError {
    /// A mutation was attempted on an immutable map.
    Immutable,
    /// The underlying header map refused another entry (converted from
    /// `http::header::MaxSizeReached`).
    TooManyFields,
    /// The operation would push the accounted size past the configured limit.
    TotalSizeTooBig,
    /// A header name was invalid (converted from `http::header::InvalidHeaderName`).
    InvalidHeaderName,
}
239
240impl fmt::Display for FieldMapError {
241 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
242 let s = match self {
243 FieldMapError::Immutable => "cannot mutate an immutable field map",
244 FieldMapError::TooManyFields => "too many fields in the field map",
245 FieldMapError::TotalSizeTooBig => "total size of fields exceeds limit",
246 FieldMapError::InvalidHeaderName => "invalid header name",
247 };
248 f.write_str(s)
249 }
250}
251
252impl std::error::Error for FieldMapError {}
253
254impl From<http::header::MaxSizeReached> for FieldMapError {
255 fn from(_: http::header::MaxSizeReached) -> Self {
256 Self::TooManyFields
257 }
258}
259
260impl From<http::header::InvalidHeaderName> for FieldMapError {
261 fn from(_: http::header::InvalidHeaderName) -> Self {
262 Self::InvalidHeaderName
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::{FieldMap, FieldMapError};

    /// Upper bound on iterations in `run_until_too_big`. Without it, a missing
    /// limit check would make the test loop (and allocate) forever instead of failing.
    const MAX_ITERATIONS: usize = 10_000;

    /// Asserts that the incrementally maintained `size` equals a full recount
    /// of the map's contents and does not exceed `limit`.
    fn assert_size_consistent(map: &FieldMap, limit: usize) {
        assert_eq!(map.size, FieldMap::content_size(map));
        assert!(map.size <= limit, "size {} exceeds limit {limit}", map.size);
    }

    /// Calls `step(i)` for `i = 0, 1, ...` until it fails with `TotalSizeTooBig`.
    ///
    /// Panics on any other error, or if the limit is not reached within
    /// `MAX_ITERATIONS` calls.
    fn run_until_too_big(mut step: impl FnMut(usize) -> Result<(), FieldMapError>) {
        for i in 0..MAX_ITERATIONS {
            match step(i) {
                Ok(()) => {}
                Err(FieldMapError::TotalSizeTooBig) => return,
                Err(e) => panic!("unexpected error: {e}"),
            }
        }
        panic!("size limit not reached after {MAX_ITERATIONS} iterations");
    }

    #[test]
    fn test_immutable() {
        let mut map = FieldMap::default();
        assert_eq!(
            map.set("foo".parse().unwrap(), vec!["bar".parse().unwrap()]),
            Err(FieldMapError::Immutable)
        );
        assert_eq!(
            map.append("foo".parse().unwrap(), "bar".parse().unwrap()),
            Err(FieldMapError::Immutable)
        );
        assert_eq!(
            map.remove_all("foo".parse().unwrap()),
            Err(FieldMapError::Immutable)
        );
        assert_eq!(map.size, 0);
    }

    #[test]
    fn test_limits() {
        const LIMIT: usize = 100;

        // Appending repeatedly to a single name.
        let mut map = FieldMap::new_mutable(LIMIT);
        run_until_too_big(|_| {
            map.append("foo".parse().unwrap(), "bar".parse().unwrap())?;
            assert_size_consistent(&map, LIMIT);
            Ok(())
        });
        assert_size_consistent(&map, LIMIT);

        // Replacing a single name with an ever-growing list of values.
        let mut map = FieldMap::new_mutable(LIMIT);
        run_until_too_big(|i| {
            map.set(
                "foo".parse().unwrap(),
                (0..i).map(|j| format!("bar{j}").parse().unwrap()).collect(),
            )?;
            assert_size_consistent(&map, LIMIT);
            Ok(())
        });
        assert_size_consistent(&map, LIMIT);

        // Adding ever more distinct names.
        let mut map = FieldMap::new_mutable(LIMIT);
        run_until_too_big(|i| {
            map.set(
                format!("foo{i}").parse().unwrap(),
                vec!["bar".parse().unwrap()],
            )?;
            assert_size_consistent(&map, LIMIT);
            Ok(())
        });
        assert_size_consistent(&map, LIMIT);
    }

    #[test]
    fn test_size() -> Result<(), FieldMapError> {
        const LIMIT: usize = 2000;
        let mut map = FieldMap::new_mutable(LIMIT);
        let name: http::HeaderName = "foo".parse().unwrap();

        map.append(name.clone(), "bar".parse().unwrap())?;
        assert!(map.size > 0);
        assert_size_consistent(&map, LIMIT);
        map.remove_all(name.clone())?;
        assert_eq!(map.size, 0);

        map.set(name.clone(), vec!["bar".parse().unwrap()])?;
        assert!(map.size > 0);
        assert_size_consistent(&map, LIMIT);
        map.remove_all(name.clone())?;
        assert_eq!(map.size, 0);

        map.set(name.clone(), vec![])?;
        assert_eq!(map.size, 0);
        map.set(name.clone(), vec!["bar".parse().unwrap()])?;
        assert!(map.size > 0);
        assert_size_consistent(&map, LIMIT);
        map.set(name.clone(), vec![])?;
        assert_eq!(map.size, 0);

        map.set(name.clone(), vec!["bar".parse().unwrap()])?;
        assert!(map.size > 0);
        map.set(
            name.clone(),
            vec!["bar".parse().unwrap(), "baz".parse().unwrap()],
        )?;
        assert!(map.size > 0);
        assert_size_consistent(&map, LIMIT);
        map.remove_all(name.clone())?;
        assert_eq!(map.size, 0);

        Ok(())
    }
}