1use crate::{
4 compiled_blob::CompiledBlob,
5 memory::{BranchProtection, JITMemoryKind, JITMemoryProvider, SystemMemoryProvider},
6};
7use cranelift_codegen::binemit::Reloc;
8use cranelift_codegen::isa::{OwnedTargetIsa, TargetIsa};
9use cranelift_codegen::settings::Configurable;
10use cranelift_codegen::{ir, settings};
11use cranelift_control::ControlPlane;
12use cranelift_entity::SecondaryMap;
13use cranelift_module::{
14 DataDescription, DataId, FuncId, Init, Linkage, Module, ModuleDeclarations, ModuleError,
15 ModuleReloc, ModuleRelocTarget, ModuleResult,
16};
17use log::info;
18use std::cell::RefCell;
19use std::collections::HashMap;
20use std::ffi::CString;
21use std::io::Write;
22use target_lexicon::PointerWidth;
23
/// Alignment (in bytes) used for writable data objects whose `DataDescription`
/// does not request an explicit alignment.
const WRITABLE_DATA_ALIGNMENT: u64 = 8;
/// Alignment (in bytes) used for read-only data objects whose `DataDescription`
/// does not request an explicit alignment.
const READONLY_DATA_ALIGNMENT: u64 = 1;
26
/// Configuration used to construct a [`JITModule`]: the target ISA, the
/// symbols and symbol-lookup functions used to resolve external references,
/// the libcall naming scheme, and an optional custom memory provider.
// Field declaration order is also the drop order; keep it stable.
pub struct JITBuilder {
    /// Target ISA that functions will be compiled for.
    isa: OwnedTargetIsa,
    /// Symbols registered explicitly via `symbol`/`symbols`. The module
    /// consults these before any lookup function.
    symbols: HashMap<String, SendWrapper<*const u8>>,
    /// Fallback symbol lookup functions. New ones are appended, and the
    /// module tries them from last to first.
    lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
    /// Maps a Cranelift libcall to the symbol name resolved for it.
    libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
    /// Custom memory provider. `None` makes `JITModule::new` fall back to a
    /// `SystemMemoryProvider`.
    memory: Option<Box<dyn JITMemoryProvider + Send>>,
}
35
36impl JITBuilder {
37 pub fn new(
44 libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
45 ) -> ModuleResult<Self> {
46 Self::with_flags(&[], libcall_names)
47 }
48
49 pub fn with_flags(
56 flags: &[(&str, &str)],
57 libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
58 ) -> ModuleResult<Self> {
59 let mut flag_builder = settings::builder();
60 for (name, value) in flags {
61 flag_builder.set(name, value)?;
62 }
63
64 flag_builder.set("use_colocated_libcalls", "false").unwrap();
68 flag_builder.set("is_pic", "false").unwrap();
69 let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
70 panic!("host machine is not supported: {msg}");
71 });
72 let isa = isa_builder.finish(settings::Flags::new(flag_builder))?;
73 Ok(Self::with_isa(isa, libcall_names))
74 }
75
76 pub fn with_isa(
87 isa: OwnedTargetIsa,
88 libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
89 ) -> Self {
90 let symbols = HashMap::new();
91 let lookup_symbols = vec![Box::new(lookup_with_dlsym) as Box<_>];
92 Self {
93 isa,
94 symbols,
95 lookup_symbols,
96 libcall_names,
97 memory: None,
98 }
99 }
100
101 pub fn symbol<K>(&mut self, name: K, ptr: *const u8) -> &mut Self
116 where
117 K: Into<String>,
118 {
119 self.symbols.insert(name.into(), SendWrapper(ptr));
120 self
121 }
122
123 pub fn symbols<It, K>(&mut self, symbols: It) -> &mut Self
127 where
128 It: IntoIterator<Item = (K, *const u8)>,
129 K: Into<String>,
130 {
131 for (name, ptr) in symbols {
132 self.symbols.insert(name.into(), SendWrapper(ptr));
133 }
134 self
135 }
136
137 pub fn symbol_lookup_fn(
142 &mut self,
143 symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8> + Send>,
144 ) -> &mut Self {
145 self.lookup_symbols.push(symbol_lookup_fn);
146 self
147 }
148
149 pub fn memory_provider(&mut self, provider: Box<dyn JITMemoryProvider + Send>) -> &mut Self {
153 self.memory = Some(provider);
154 self
155 }
156}
157
/// Newtype that marks its contents as `Send`.
///
/// It lets raw symbol addresses (`*const u8`) be stored in `JITBuilder` and
/// `JITModule`. `JITModule` is asserted to be `Send` at the bottom of this file.
#[derive(Copy, Clone)]
struct SendWrapper<T>(T);
// SAFETY: NOTE(review): this impl is unconditional in `T`. In this file it only
// wraps `*const u8` symbol addresses. The file reads them via `.0` and hands
// them to relocation processing as plain addresses. Any use with another `T`
// must be re-audited before relying on this impl.
unsafe impl<T> Send for SendWrapper<T> {}
164
/// A `Module` implementation that places compiled functions and data objects
/// in memory of the current process, so they can be used at runtime after
/// `finalize_definitions`.
// Field declaration order is also the drop order; keep it stable.
pub struct JITModule {
    /// Target ISA used to compile functions.
    isa: OwnedTargetIsa,
    /// Cache of resolved symbol addresses. It is seeded from the builder's
    /// explicit symbols and extended with successful `lookup_symbols` results.
    /// `RefCell` is used because lookups happen through `&self`.
    symbols: RefCell<HashMap<String, SendWrapper<*const u8>>>,
    /// Fallback symbol lookup functions, tried from last to first.
    lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
    /// Maps a Cranelift libcall to the symbol name resolved for it.
    libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
    /// Allocator for code and data blobs. It is finalized, with the selected
    /// branch protection, in `finalize_definitions`.
    memory: Box<dyn JITMemoryProvider + Send>,
    /// Function and data declarations (names, linkage, signatures).
    declarations: ModuleDeclarations,
    /// Compiled code per function; `None` until the function is defined.
    compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
    /// Compiled contents per data object; `None` until the object is defined.
    compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
    /// `(start, end, func)` address ranges of defined functions. They are
    /// sorted by `start` in `finalize_definitions` so they can be
    /// binary-searched.
    code_ranges: Vec<(usize, usize, FuncId)>,
    /// Functions defined since the last `finalize_definitions` whose
    /// relocations have not been applied yet.
    functions_to_finalize: Vec<FuncId>,
    /// Data objects defined since the last `finalize_definitions` whose
    /// relocations have not been applied yet.
    data_objects_to_finalize: Vec<DataId>,
}
182
183impl JITModule {
184 pub unsafe fn free_memory(mut self) {
193 self.memory.free_memory();
194 }
195
196 fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
197 match self.symbols.borrow_mut().entry(name.to_owned()) {
198 std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0),
199 std::collections::hash_map::Entry::Vacant(vac) => {
200 let ptr = self
201 .lookup_symbols
202 .iter()
203 .rev() .find_map(|lookup| lookup(name));
205 if let Some(ptr) = ptr {
206 vac.insert(SendWrapper(ptr));
207 }
208 ptr
209 }
210 }
211 }
212
213 fn get_address(&self, name: &ModuleRelocTarget) -> *const u8 {
214 match name {
215 ModuleRelocTarget::User { .. } => {
216 let (name, linkage) = if ModuleDeclarations::is_function(name) {
217 let func_id = FuncId::from_name(name);
218 match &self.compiled_functions[func_id] {
219 Some(compiled) => return compiled.ptr(),
220 None => {
221 let decl = self.declarations.get_function_decl(func_id);
222 (&decl.name, decl.linkage)
223 }
224 }
225 } else {
226 let data_id = DataId::from_name(name);
227 match &self.compiled_data_objects[data_id] {
228 Some(compiled) => return compiled.ptr(),
229 None => {
230 let decl = self.declarations.get_data_decl(data_id);
231 (&decl.name, decl.linkage)
232 }
233 }
234 };
235 let name = name
236 .as_ref()
237 .expect("anonymous symbol must be defined locally");
238 if let Some(ptr) = self.lookup_symbol(name) {
239 ptr
240 } else if linkage == Linkage::Preemptible {
241 0 as *const u8
242 } else {
243 panic!("can't resolve symbol {name}");
244 }
245 }
246 ModuleRelocTarget::LibCall(libcall) => {
247 let sym = (self.libcall_names)(*libcall);
248 self.lookup_symbol(&sym)
249 .unwrap_or_else(|| panic!("can't resolve libcall {sym}"))
250 }
251 ModuleRelocTarget::FunctionOffset(func_id, offset) => {
252 match &self.compiled_functions[*func_id] {
253 Some(compiled) => return compiled.ptr().wrapping_add(*offset as usize),
254 None => todo!(),
255 }
256 }
257 name => panic!("invalid name {name:?}"),
258 }
259 }
260
261 pub fn get_finalized_function(&self, func_id: FuncId) -> *const u8 {
266 let info = &self.compiled_functions[func_id];
267 assert!(
268 !self.functions_to_finalize.iter().any(|x| *x == func_id),
269 "function not yet finalized"
270 );
271 info.as_ref()
272 .expect("function must be compiled before it can be finalized")
273 .ptr()
274 }
275
276 pub fn get_finalized_data(&self, data_id: DataId) -> (*const u8, usize) {
281 let info = &self.compiled_data_objects[data_id];
282 assert!(
283 !self.data_objects_to_finalize.iter().any(|x| *x == data_id),
284 "data object not yet finalized"
285 );
286 let compiled = info
287 .as_ref()
288 .expect("data object must be compiled before it can be finalized");
289
290 (compiled.ptr(), compiled.size())
291 }
292
293 fn record_function_for_perf(&self, ptr: *const u8, size: usize, name: &str) {
294 if cfg!(unix) && ::std::env::var_os("PERF_BUILDID_DIR").is_some() {
300 let mut map_file = ::std::fs::OpenOptions::new()
301 .create(true)
302 .append(true)
303 .open(format!("/tmp/perf-{}.map", ::std::process::id()))
304 .unwrap();
305
306 let _ = writeln!(map_file, "{:x} {:x} {}", ptr as usize, size, name);
307 }
308 }
309
310 pub fn finalize_definitions(&mut self) -> ModuleResult<()> {
319 for func in std::mem::take(&mut self.functions_to_finalize) {
320 let decl = self.declarations.get_function_decl(func);
321 assert!(decl.linkage.is_definable());
322 let func = self.compiled_functions[func]
323 .as_ref()
324 .expect("function must be compiled before it can be finalized");
325 func.perform_relocations(|name| self.get_address(name));
326 }
327
328 for data in std::mem::take(&mut self.data_objects_to_finalize) {
329 let decl = self.declarations.get_data_decl(data);
330 assert!(decl.linkage.is_definable());
331 let data = self.compiled_data_objects[data]
332 .as_ref()
333 .expect("data object must be compiled before it can be finalized");
334 data.perform_relocations(|name| self.get_address(name));
335 }
336
337 self.code_ranges
338 .sort_unstable_by_key(|(start, _end, _)| *start);
339
340 let branch_protection = if cfg!(target_arch = "aarch64") && use_bti(&self.isa.isa_flags()) {
342 BranchProtection::BTI
343 } else {
344 BranchProtection::None
345 };
346 self.memory.finalize(branch_protection)?;
347
348 Ok(())
349 }
350
351 pub fn new(builder: JITBuilder) -> Self {
353 assert!(
354 !builder.isa.flags().is_pic(),
355 "cranelift-jit needs is_pic=false"
356 );
357
358 let memory = builder
359 .memory
360 .unwrap_or_else(|| Box::new(SystemMemoryProvider::new()));
361 Self {
362 isa: builder.isa,
363 symbols: RefCell::new(builder.symbols),
364 lookup_symbols: builder.lookup_symbols,
365 libcall_names: builder.libcall_names,
366 memory,
367 declarations: ModuleDeclarations::default(),
368 compiled_functions: SecondaryMap::new(),
369 compiled_data_objects: SecondaryMap::new(),
370 code_ranges: Vec::new(),
371 functions_to_finalize: Vec::new(),
372 data_objects_to_finalize: Vec::new(),
373 }
374 }
375
376 #[cfg(feature = "wasmtime-unwinder")]
380 pub fn lookup_wasmtime_exception_data<'a>(
381 &'a self,
382 pc: usize,
383 ) -> Option<(usize, wasmtime_unwinder::ExceptionTable<'a>)> {
384 let idx = match self
386 .code_ranges
387 .binary_search_by_key(&pc, |(start, _end, _func)| *start)
388 {
389 Ok(exact_start_match) => Some(exact_start_match),
390 Err(least_upper_bound) if least_upper_bound > 0 => {
391 let last_range_before_pc = &self.code_ranges[least_upper_bound - 1];
392 if last_range_before_pc.0 <= pc && pc < last_range_before_pc.1 {
393 Some(least_upper_bound - 1)
394 } else {
395 None
396 }
397 }
398 _ => None,
399 }?;
400
401 let (start, _, func) = self.code_ranges[idx];
402
403 let data = self.compiled_functions[func]
407 .as_ref()
408 .unwrap()
409 .wasmtime_exception_data()?;
410 let exception_table = wasmtime_unwinder::ExceptionTable::parse(data).ok()?;
411 Some((start, exception_table))
412 }
413}
414
415impl Module for JITModule {
416 fn isa(&self) -> &dyn TargetIsa {
417 &*self.isa
418 }
419
420 fn declarations(&self) -> &ModuleDeclarations {
421 &self.declarations
422 }
423
424 fn declare_function(
425 &mut self,
426 name: &str,
427 linkage: Linkage,
428 signature: &ir::Signature,
429 ) -> ModuleResult<FuncId> {
430 let (id, _linkage) = self
431 .declarations
432 .declare_function(name, linkage, signature)?;
433 Ok(id)
434 }
435
436 fn declare_anonymous_function(&mut self, signature: &ir::Signature) -> ModuleResult<FuncId> {
437 let id = self.declarations.declare_anonymous_function(signature)?;
438 Ok(id)
439 }
440
441 fn declare_data(
442 &mut self,
443 name: &str,
444 linkage: Linkage,
445 writable: bool,
446 tls: bool,
447 ) -> ModuleResult<DataId> {
448 assert!(!tls, "JIT doesn't yet support TLS");
449 let (id, _linkage) = self
450 .declarations
451 .declare_data(name, linkage, writable, tls)?;
452 Ok(id)
453 }
454
455 fn declare_anonymous_data(&mut self, writable: bool, tls: bool) -> ModuleResult<DataId> {
456 assert!(!tls, "JIT doesn't yet support TLS");
457 let id = self.declarations.declare_anonymous_data(writable, tls)?;
458 Ok(id)
459 }
460
461 fn define_function_with_control_plane(
462 &mut self,
463 id: FuncId,
464 ctx: &mut cranelift_codegen::Context,
465 ctrl_plane: &mut ControlPlane,
466 ) -> ModuleResult<()> {
467 info!("defining function {}: {}", id, ctx.func.display());
468 let decl = self.declarations.get_function_decl(id);
469 if !decl.linkage.is_definable() {
470 return Err(ModuleError::InvalidImportDefinition(
471 decl.linkage_name(id).into_owned(),
472 ));
473 }
474
475 if !self.compiled_functions[id].is_none() {
476 return Err(ModuleError::DuplicateDefinition(
477 decl.linkage_name(id).into_owned(),
478 ));
479 }
480
481 let res = ctx.compile(self.isa(), ctrl_plane)?;
483 let alignment = res.buffer.alignment as u64;
484 let compiled_code = ctx.compiled_code().unwrap();
485
486 let align = alignment
487 .max(self.isa.function_alignment().minimum as u64)
488 .max(self.isa.symbol_alignment());
489
490 let relocs = compiled_code
491 .buffer
492 .relocs()
493 .iter()
494 .map(|reloc| ModuleReloc::from_mach_reloc(reloc, &ctx.func, id))
495 .collect();
496
497 #[cfg(feature = "wasmtime-unwinder")]
498 let wasmtime_exception_data = {
499 let mut exception_builder = wasmtime_unwinder::ExceptionTableBuilder::default();
500 exception_builder
501 .add_func(0, compiled_code.buffer.call_sites())
502 .map_err(|_| {
503 ModuleError::Compilation(cranelift_codegen::CodegenError::Unsupported(
504 "Invalid exception data".into(),
505 ))
506 })?;
507 Some(exception_builder.to_vec())
508 };
509
510 let blob = self.compiled_functions[id].insert(CompiledBlob::new(
511 &mut *self.memory,
512 compiled_code.code_buffer(),
513 align,
514 relocs,
515 #[cfg(feature = "wasmtime-unwinder")]
516 wasmtime_exception_data,
517 JITMemoryKind::Executable,
518 )?);
519 let (ptr, size) = (blob.ptr(), blob.size());
520 self.record_function_for_perf(ptr, size, &decl.linkage_name(id));
521
522 let range_start = ptr.addr();
523 let range_end = range_start + size;
524 self.code_ranges.push((range_start, range_end, id));
526
527 self.functions_to_finalize.push(id);
528
529 Ok(())
530 }
531
532 fn define_function_bytes(
533 &mut self,
534 id: FuncId,
535 alignment: u64,
536 bytes: &[u8],
537 relocs: &[ModuleReloc],
538 ) -> ModuleResult<()> {
539 info!("defining function {id} with bytes");
540 let decl = self.declarations.get_function_decl(id);
541 if !decl.linkage.is_definable() {
542 return Err(ModuleError::InvalidImportDefinition(
543 decl.linkage_name(id).into_owned(),
544 ));
545 }
546
547 if !self.compiled_functions[id].is_none() {
548 return Err(ModuleError::DuplicateDefinition(
549 decl.linkage_name(id).into_owned(),
550 ));
551 }
552
553 let align = alignment
554 .max(self.isa.function_alignment().minimum as u64)
555 .max(self.isa.symbol_alignment());
556
557 let blob = self.compiled_functions[id].insert(CompiledBlob::new(
558 &mut *self.memory,
559 bytes,
560 align,
561 relocs.to_owned(),
562 #[cfg(feature = "wasmtime-unwinder")]
563 None,
564 JITMemoryKind::Executable,
565 )?);
566 let (ptr, size) = (blob.ptr(), blob.size());
567 self.record_function_for_perf(ptr, size, &decl.linkage_name(id));
568
569 self.functions_to_finalize.push(id);
570
571 Ok(())
572 }
573
574 fn define_data(&mut self, id: DataId, data: &DataDescription) -> ModuleResult<()> {
575 let decl = self.declarations.get_data_decl(id);
576 if !decl.linkage.is_definable() {
577 return Err(ModuleError::InvalidImportDefinition(
578 decl.linkage_name(id).into_owned(),
579 ));
580 }
581
582 if !self.compiled_data_objects[id].is_none() {
583 return Err(ModuleError::DuplicateDefinition(
584 decl.linkage_name(id).into_owned(),
585 ));
586 }
587
588 assert!(!decl.tls, "JIT doesn't yet support TLS");
589
590 let &DataDescription {
591 ref init,
592 function_decls: _,
593 data_decls: _,
594 function_relocs: _,
595 data_relocs: _,
596 custom_segment_section: _,
597 align,
598 used: _,
599 } = data;
600
601 let (align, kind) = if decl.writable {
602 (
603 align.unwrap_or(WRITABLE_DATA_ALIGNMENT),
604 JITMemoryKind::Writable,
605 )
606 } else {
607 (
608 align.unwrap_or(READONLY_DATA_ALIGNMENT),
609 JITMemoryKind::ReadOnly,
610 )
611 };
612
613 let pointer_reloc = match self.isa.triple().pointer_width().unwrap() {
614 PointerWidth::U16 => panic!(),
615 PointerWidth::U32 => Reloc::Abs4,
616 PointerWidth::U64 => Reloc::Abs8,
617 };
618 let relocs = data.all_relocs(pointer_reloc).collect::<Vec<_>>();
619
620 self.compiled_data_objects[id] = Some(match *init {
621 Init::Uninitialized => {
622 panic!("data is not initialized yet");
623 }
624 Init::Zeros { size } => CompiledBlob::new_zeroed(
625 &mut *self.memory,
626 size.max(1),
627 align,
628 relocs,
629 #[cfg(feature = "wasmtime-unwinder")]
630 None,
631 kind,
632 )?,
633 Init::Bytes { ref contents } => CompiledBlob::new(
634 &mut *self.memory,
635 if contents.is_empty() {
636 &[0]
642 } else {
643 &contents[..]
644 },
645 align,
646 relocs,
647 #[cfg(feature = "wasmtime-unwinder")]
648 None,
649 kind,
650 )?,
651 });
652
653 self.data_objects_to_finalize.push(id);
654
655 Ok(())
656 }
657}
658
659#[cfg(not(windows))]
660fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
661 let c_str = CString::new(name).unwrap();
662 let c_str_ptr = c_str.as_ptr();
663 let sym = unsafe { libc::dlsym(libc::RTLD_DEFAULT, c_str_ptr) };
664 if sym.is_null() {
665 None
666 } else {
667 Some(sym as *const u8)
668 }
669}
670
/// Resolves `name` with `GetProcAddress`.
///
/// It tries a null module handle first, then the handle that
/// `GetModuleHandleA("ucrtbase.dll")` returns, and yields the first address
/// found.
///
/// Returns `None` if neither lookup finds the symbol, or if `name` contains an
/// interior NUL byte (it cannot be passed as a C string).
#[cfg(windows)]
fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
    use std::os::windows::io::RawHandle;
    use std::ptr;
    use windows_sys::Win32::Foundation::HMODULE;
    use windows_sys::Win32::System::LibraryLoader;

    const UCRTBASE: &[u8] = b"ucrtbase.dll\0";

    let c_str = CString::new(name).ok()?;
    let c_str_ptr = c_str.as_ptr();

    // SAFETY: `UCRTBASE` and `c_str` are NUL-terminated and outlive both calls.
    unsafe {
        let handles = [
            // NOTE(review): null handle, presumably meant to search the running
            // executable; confirm against the `GetProcAddress` documentation.
            ptr::null_mut(),
            // The local C runtime.
            LibraryLoader::GetModuleHandleA(UCRTBASE.as_ptr()) as RawHandle,
        ];

        handles.iter().find_map(|&handle| {
            LibraryLoader::GetProcAddress(handle as HMODULE, c_str_ptr.cast())
                .map(|addr| addr as *const u8)
        })
    }
}
702
703fn use_bti(isa_flags: &Vec<settings::Value>) -> bool {
704 isa_flags
705 .iter()
706 .find(|&f| f.name == "use_bti")
707 .map_or(false, |f| f.as_bool().unwrap_or(false))
708}
709
710const _ASSERT_JIT_MODULE_IS_SEND: () = {
711 const fn assert_is_send<T: Send>() {}
712 assert_is_send::<JITModule>();
713};