1use hashbrown::hash_table::HashTable;
8use std::hash::{Hash, Hasher};
9
/// Equality of two values, decided by an external context.
///
/// The context (the implementing type), not the values themselves, performs
/// the comparison. This lets values that are only meaningful relative to some
/// context be compared — for example, small keys that index into data held by
/// the context.
pub trait CtxEq<V1, V2>
where
    V1: ?Sized,
    V2: ?Sized,
{
    /// Returns `true` when `a` and `b` are equal as seen through this context.
    fn ctx_eq(&self, a: &V1, b: &V2) -> bool;
}
22
23pub trait CtxHash<Value: ?Sized>: CtxEq<Value, Value> {
25 fn ctx_hash<H: Hasher>(&self, state: &mut H, value: &Value);
28}
29
/// A context that adds nothing: values are hashed and compared with their
/// own `Hash` and `Eq` implementations.
///
/// It is a zero-sized unit type, so it is `Copy` and can be passed around
/// freely. `Debug` is derived so it can appear in diagnostics like any other
/// public type.
#[derive(Clone, Copy, Debug, Default)]
pub struct NullCtx;
34
35impl<V: Eq + Hash> CtxEq<V, V> for NullCtx {
36 fn ctx_eq(&self, a: &V, b: &V) -> bool {
37 a.eq(b)
38 }
39}
40impl<V: Eq + Hash> CtxHash<V> for NullCtx {
41 fn ctx_hash<H: Hasher>(&self, state: &mut H, value: &V) {
42 value.hash(state);
43 }
44}
45
/// One stored entry: a key, its value, and the key's cached context hash.
struct BucketData<K, V> {
    /// The 32-bit hash of `k` from `compute_hash`. It is checked before
    /// the context's (potentially more expensive) equality test.
    hash: u32,
    /// The stored key.
    k: K,
    /// The value associated with `k`.
    v: V,
}
59
60pub struct CtxHashMap<K, V> {
62 raw: HashTable<BucketData<K, V>>,
63}
64
65impl<K, V> CtxHashMap<K, V> {
66 pub fn with_capacity(capacity: usize) -> Self {
69 Self {
70 raw: HashTable::with_capacity(capacity),
71 }
72 }
73}
74
75fn compute_hash<Ctx, K>(ctx: &Ctx, k: &K) -> u32
76where
77 Ctx: CtxHash<K>,
78{
79 let mut hasher = rustc_hash::FxHasher::default();
80 ctx.ctx_hash(&mut hasher, k);
81 hasher.finish() as u32
82}
83
84impl<K, V> CtxHashMap<K, V> {
85 pub fn insert<Ctx>(&mut self, k: K, v: V, ctx: &Ctx) -> Option<V>
88 where
89 Ctx: CtxEq<K, K> + CtxHash<K>,
90 {
91 let hash = compute_hash(ctx, &k);
92 match self.raw.find_mut(hash as u64, |bucket| {
93 hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k)
94 }) {
95 Some(bucket) => Some(std::mem::replace(&mut bucket.v, v)),
96 None => {
97 let data = BucketData { hash, k, v };
98 self.raw
99 .insert_unique(hash as u64, data, |bucket| bucket.hash as u64);
100 None
101 }
102 }
103 }
104
105 pub fn get<'a, Q, Ctx>(&'a self, k: &Q, ctx: &Ctx) -> Option<&'a V>
107 where
108 Ctx: CtxEq<K, Q> + CtxHash<Q> + CtxHash<K>,
109 {
110 let hash = compute_hash(ctx, k);
111 self.raw
112 .find(hash as u64, |bucket| {
113 hash == bucket.hash && ctx.ctx_eq(&bucket.k, k)
114 })
115 .map(|bucket| &bucket.v)
116 }
117
118 pub fn entry<'a, Ctx>(&'a mut self, k: K, ctx: &Ctx) -> Entry<'a, K, V>
121 where
122 Ctx: CtxEq<K, K> + CtxHash<K>,
123 {
124 let hash = compute_hash(ctx, &k);
125 let raw = self.raw.entry(
126 hash as u64,
127 |bucket| hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k),
128 |bucket| compute_hash(ctx, &bucket.k) as u64,
129 );
130 match raw {
131 hashbrown::hash_table::Entry::Occupied(o) => Entry::Occupied(OccupiedEntry { raw: o }),
132 hashbrown::hash_table::Entry::Vacant(v) => Entry::Vacant(VacantEntry {
133 hash,
134 key: k,
135 raw: v,
136 }),
137 }
138 }
139}
140
141pub enum Entry<'a, K, V> {
143 Occupied(OccupiedEntry<'a, K, V>),
144 Vacant(VacantEntry<'a, K, V>),
145}
146
147pub struct OccupiedEntry<'a, K, V> {
149 raw: hashbrown::hash_table::OccupiedEntry<'a, BucketData<K, V>>,
150}
151
152pub struct VacantEntry<'a, K, V> {
154 hash: u32,
155 key: K,
156 raw: hashbrown::hash_table::VacantEntry<'a, BucketData<K, V>>,
157}
158
159impl<'a, K, V> OccupiedEntry<'a, K, V> {
160 pub fn get(&self) -> &V {
162 &self.raw.get().v
163 }
164
165 pub fn get_mut(&mut self) -> &mut V {
167 &mut self.raw.get_mut().v
168 }
169}
170
171impl<'a, K, V> VacantEntry<'a, K, V> {
172 pub fn insert(self, v: V) {
174 self.raw.insert(BucketData {
175 hash: self.hash,
176 k: self.key,
177 v,
178 });
179 }
180}
181
#[cfg(test)]
mod test {
    use super::*;

    /// A key that only has meaning relative to a `Ctx`: it indexes `Ctx::vals`.
    #[derive(Clone, Copy, Debug)]
    struct Key {
        index: u32,
    }

    /// Resolves keys to strings, so keys with different indices are equal
    /// whenever they point at equal strings.
    struct Ctx {
        vals: &'static [&'static str],
    }

    impl CtxEq<Key, Key> for Ctx {
        fn ctx_eq(&self, a: &Key, b: &Key) -> bool {
            self.vals[a.index as usize].eq(self.vals[b.index as usize])
        }
    }

    impl CtxHash<Key> for Ctx {
        fn ctx_hash<H: Hasher>(&self, state: &mut H, value: &Key) {
            self.vals[value.index as usize].hash(state);
        }
    }

    #[test]
    fn test_basic() {
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };

        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };

        assert!(ctx.ctx_eq(&k0, &k2));
        assert!(!ctx.ctx_eq(&k0, &k1));
        assert!(!ctx.ctx_eq(&k2, &k1));

        let mut map: CtxHashMap<Key, u64> = CtxHashMap::with_capacity(4);
        assert_eq!(map.insert(k0, 42, &ctx), None);
        assert_eq!(map.insert(k2, 84, &ctx), Some(42));
        assert_eq!(map.get(&k1, &ctx), None);
        assert_eq!(*map.get(&k0, &ctx).unwrap(), 84);
    }

    #[test]
    fn test_entry() {
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };
        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };

        let mut map: CtxHashMap<Key, u64> = CtxHashMap::with_capacity(4);

        match map.entry(k0, &ctx) {
            Entry::Vacant(slot) => slot.insert(1),
            Entry::Occupied(_) => panic!("empty map reported an occupied entry"),
        }

        // `k2` resolves to the same string as `k0`, so it must hit the same entry.
        match map.entry(k2, &ctx) {
            Entry::Occupied(mut slot) => {
                assert_eq!(*slot.get(), 1);
                *slot.get_mut() = 2;
            }
            Entry::Vacant(_) => panic!("context-equal key was not found"),
        }
        assert_eq!(map.get(&k0, &ctx), Some(&2));

        match map.entry(k1, &ctx) {
            Entry::Vacant(slot) => slot.insert(3),
            Entry::Occupied(_) => panic!("distinct key reported as occupied"),
        }
        assert_eq!(map.get(&k1, &ctx), Some(&3));
        assert_eq!(map.get(&k2, &ctx), Some(&2));
    }

    #[test]
    fn test_growth() {
        // Insert far past the initial capacity so the table must rehash,
        // through both `insert` and `entry`.
        let mut map: CtxHashMap<u32, u32> = CtxHashMap::with_capacity(1);
        for i in 0..1000u32 {
            assert_eq!(map.insert(i, i * 2, &NullCtx), None);
        }
        for i in 1000..2000u32 {
            match map.entry(i, &NullCtx) {
                Entry::Vacant(slot) => slot.insert(i * 2),
                Entry::Occupied(_) => panic!("fresh key reported as occupied"),
            }
        }
        for i in 0..2000u32 {
            assert_eq!(map.get(&i, &NullCtx), Some(&(i * 2)));
            match map.entry(i, &NullCtx) {
                Entry::Occupied(slot) => assert_eq!(*slot.get(), i * 2),
                Entry::Vacant(_) => panic!("stored key not found via entry"),
            }
        }
        assert_eq!(map.get(&2000u32, &NullCtx), None);
    }
}