1use crate::{
4 compiled_blob::CompiledBlob,
5 memory::{BranchProtection, JITMemoryProvider, SystemMemoryProvider},
6};
7use cranelift_codegen::binemit::Reloc;
8use cranelift_codegen::isa::{OwnedTargetIsa, TargetIsa};
9use cranelift_codegen::settings::Configurable;
10use cranelift_codegen::{ir, settings};
11use cranelift_control::ControlPlane;
12use cranelift_entity::SecondaryMap;
13use cranelift_module::{
14 DataDescription, DataId, FuncId, Init, Linkage, Module, ModuleDeclarations, ModuleError,
15 ModuleReloc, ModuleRelocTarget, ModuleResult,
16};
17use log::info;
18use std::cell::RefCell;
19use std::collections::HashMap;
20use std::ffi::CString;
21use std::io::Write;
22use std::ptr;
23use target_lexicon::PointerWidth;
24
/// Default alignment, in bytes, used by `define_data` for writable data
/// objects whose `DataDescription` does not specify `align`.
const WRITABLE_DATA_ALIGNMENT: u64 = 0x8;
/// Default alignment, in bytes, used by `define_data` for read-only data
/// objects whose `DataDescription` does not specify `align`.
const READONLY_DATA_ALIGNMENT: u64 = 0x1;
27
/// Collects everything needed to construct a [`JITModule`]: the target ISA,
/// explicitly registered symbols, symbol lookup callbacks, the libcall naming
/// function, and an optional memory provider.
pub struct JITBuilder {
    /// Target ISA the resulting module compiles for.
    isa: OwnedTargetIsa,
    /// Symbols registered by name through `symbol` / `symbols`.
    symbols: HashMap<String, SendWrapper<*const u8>>,
    /// Lookup callbacks used for names that are not in `symbols`;
    /// `JITModule::lookup_symbol` tries the most recently added one first.
    lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
    /// Maps a Cranelift libcall to the symbol name used to resolve it.
    libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
    /// Custom memory provider; when `None`, `JITModule::new` falls back to
    /// `SystemMemoryProvider`.
    memory: Option<Box<dyn JITMemoryProvider>>,
}
36
37impl JITBuilder {
    /// Creates a builder for the host machine without any extra flags.
    ///
    /// Equivalent to `Self::with_flags(&[], libcall_names)`.
    ///
    /// `libcall_names` maps each `ir::LibCall` to the symbol name that is
    /// later used to resolve it.
    pub fn new(
        libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
    ) -> ModuleResult<Self> {
        Self::with_flags(&[], libcall_names)
    }
49
50 pub fn with_flags(
57 flags: &[(&str, &str)],
58 libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
59 ) -> ModuleResult<Self> {
60 let mut flag_builder = settings::builder();
61 for (name, value) in flags {
62 flag_builder.set(name, value)?;
63 }
64
65 flag_builder.set("use_colocated_libcalls", "false").unwrap();
69 flag_builder.set("is_pic", "false").unwrap();
70 let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
71 panic!("host machine is not supported: {msg}");
72 });
73 let isa = isa_builder.finish(settings::Flags::new(flag_builder))?;
74 Ok(Self::with_isa(isa, libcall_names))
75 }
76
77 pub fn with_isa(
88 isa: OwnedTargetIsa,
89 libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
90 ) -> Self {
91 let symbols = HashMap::new();
92 let lookup_symbols = vec![Box::new(lookup_with_dlsym) as Box<_>];
93 Self {
94 isa,
95 symbols,
96 lookup_symbols,
97 libcall_names,
98 memory: None,
99 }
100 }
101
    /// Registers the address `ptr` under `name`.
    ///
    /// A later registration of the same name replaces the earlier one.
    /// Returns `self` so calls can be chained.
    pub fn symbol<K>(&mut self, name: K, ptr: *const u8) -> &mut Self
    where
        K: Into<String>,
    {
        self.symbols.insert(name.into(), SendWrapper(ptr));
        self
    }
123
124 pub fn symbols<It, K>(&mut self, symbols: It) -> &mut Self
128 where
129 It: IntoIterator<Item = (K, *const u8)>,
130 K: Into<String>,
131 {
132 for (name, ptr) in symbols {
133 self.symbols.insert(name.into(), SendWrapper(ptr));
134 }
135 self
136 }
137
    /// Appends a symbol lookup callback.
    ///
    /// `JITModule::lookup_symbol` walks the callbacks in reverse registration
    /// order, so the callback added last is consulted first and the default
    /// one installed by `with_isa` is consulted last.
    /// Returns `self` so calls can be chained.
    pub fn symbol_lookup_fn(
        &mut self,
        symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8> + Send>,
    ) -> &mut Self {
        self.lookup_symbols.push(symbol_lookup_fn);
        self
    }
149
    /// Sets the memory provider the module will allocate from, replacing any
    /// provider set earlier. Returns `self` so calls can be chained.
    pub fn memory_provider(&mut self, provider: Box<dyn JITMemoryProvider>) -> &mut Self {
        self.memory = Some(provider);
        self
    }
157}
158
/// Newtype that unconditionally implements `Send` for its contents.
///
/// In this file it wraps the `*const u8` symbol addresses stored in the
/// symbol tables.
#[derive(Copy, Clone)]
struct SendWrapper<T>(T);
// SAFETY: NOTE(review) — this impl has no bound on `T`, so soundness depends
// on only ever wrapping values that may really be moved across threads;
// confirm for every use site.
unsafe impl<T> Send for SendWrapper<T> {}
165
/// A `Module` implementation that places compiled functions and data objects
/// in memory obtained from a `JITMemoryProvider` and resolves their
/// relocations in-process.
pub struct JITModule {
    /// Target ISA used for compilation.
    isa: OwnedTargetIsa,
    /// Name -> address table: seeded from the builder and extended by
    /// `lookup_symbol` with every successfully resolved name.
    symbols: RefCell<HashMap<String, SendWrapper<*const u8>>>,
    /// Lookup callbacks, tried most-recently-registered first.
    lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
    /// Maps a Cranelift libcall to the symbol name used to resolve it.
    libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
    /// Provider of the memory that code and data are copied into.
    memory: Box<dyn JITMemoryProvider>,
    /// Declarations of all functions and data objects in this module.
    declarations: ModuleDeclarations,
    /// Per-function compiled blob; `None` until the function is defined.
    compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
    /// Per-data-object compiled blob; `None` until the object is defined.
    compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
    /// `(start, end, function)` address ranges recorded by
    /// `define_function_with_control_plane`; sorted by start address in
    /// `finalize_definitions`.
    code_ranges: Vec<(usize, usize, FuncId)>,
    /// Functions that are defined but whose relocations are not applied yet.
    functions_to_finalize: Vec<FuncId>,
    /// Data objects that are defined but whose relocations are not applied yet.
    data_objects_to_finalize: Vec<DataId>,
}
183
184impl JITModule {
    /// Consumes the module and asks its memory provider to free its memory.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably no pointer previously handed out by this
    /// module (function or data addresses) may be used after this call —
    /// confirm against the contract of `JITMemoryProvider::free_memory`.
    pub unsafe fn free_memory(mut self) {
        self.memory.free_memory();
    }
196
197 fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
198 match self.symbols.borrow_mut().entry(name.to_owned()) {
199 std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0),
200 std::collections::hash_map::Entry::Vacant(vac) => {
201 let ptr = self
202 .lookup_symbols
203 .iter()
204 .rev() .find_map(|lookup| lookup(name));
206 if let Some(ptr) = ptr {
207 vac.insert(SendWrapper(ptr));
208 }
209 ptr
210 }
211 }
212 }
213
214 fn get_address(&self, name: &ModuleRelocTarget) -> *const u8 {
215 match *name {
216 ModuleRelocTarget::User { .. } => {
217 let (name, linkage) = if ModuleDeclarations::is_function(name) {
218 let func_id = FuncId::from_name(name);
219 match &self.compiled_functions[func_id] {
220 Some(compiled) => return compiled.ptr,
221 None => {
222 let decl = self.declarations.get_function_decl(func_id);
223 (&decl.name, decl.linkage)
224 }
225 }
226 } else {
227 let data_id = DataId::from_name(name);
228 match &self.compiled_data_objects[data_id] {
229 Some(compiled) => return compiled.ptr,
230 None => {
231 let decl = self.declarations.get_data_decl(data_id);
232 (&decl.name, decl.linkage)
233 }
234 }
235 };
236 let name = name
237 .as_ref()
238 .expect("anonymous symbol must be defined locally");
239 if let Some(ptr) = self.lookup_symbol(name) {
240 ptr
241 } else if linkage == Linkage::Preemptible {
242 0 as *const u8
243 } else {
244 panic!("can't resolve symbol {name}");
245 }
246 }
247 ModuleRelocTarget::LibCall(ref libcall) => {
248 let sym = (self.libcall_names)(*libcall);
249 self.lookup_symbol(&sym)
250 .unwrap_or_else(|| panic!("can't resolve libcall {sym}"))
251 }
252 _ => panic!("invalid name"),
253 }
254 }
255
256 pub fn get_finalized_function(&self, func_id: FuncId) -> *const u8 {
261 let info = &self.compiled_functions[func_id];
262 assert!(
263 !self.functions_to_finalize.iter().any(|x| *x == func_id),
264 "function not yet finalized"
265 );
266 info.as_ref()
267 .expect("function must be compiled before it can be finalized")
268 .ptr
269 }
270
271 pub fn get_finalized_data(&self, data_id: DataId) -> (*const u8, usize) {
276 let info = &self.compiled_data_objects[data_id];
277 assert!(
278 !self.data_objects_to_finalize.iter().any(|x| *x == data_id),
279 "data object not yet finalized"
280 );
281 let compiled = info
282 .as_ref()
283 .expect("data object must be compiled before it can be finalized");
284
285 (compiled.ptr, compiled.size)
286 }
287
288 fn record_function_for_perf(&self, ptr: *mut u8, size: usize, name: &str) {
289 if cfg!(unix) && ::std::env::var_os("PERF_BUILDID_DIR").is_some() {
295 let mut map_file = ::std::fs::OpenOptions::new()
296 .create(true)
297 .append(true)
298 .open(format!("/tmp/perf-{}.map", ::std::process::id()))
299 .unwrap();
300
301 let _ = writeln!(map_file, "{:x} {:x} {}", ptr as usize, size, name);
302 }
303 }
304
305 pub fn finalize_definitions(&mut self) -> ModuleResult<()> {
314 for func in std::mem::take(&mut self.functions_to_finalize) {
315 let decl = self.declarations.get_function_decl(func);
316 assert!(decl.linkage.is_definable());
317 let func = self.compiled_functions[func]
318 .as_ref()
319 .expect("function must be compiled before it can be finalized");
320 func.perform_relocations(|name| self.get_address(name));
321 }
322
323 for data in std::mem::take(&mut self.data_objects_to_finalize) {
324 let decl = self.declarations.get_data_decl(data);
325 assert!(decl.linkage.is_definable());
326 let data = self.compiled_data_objects[data]
327 .as_ref()
328 .expect("data object must be compiled before it can be finalized");
329 data.perform_relocations(|name| self.get_address(name));
330 }
331
332 self.code_ranges
333 .sort_unstable_by_key(|(start, _end, _)| *start);
334
335 let branch_protection = if cfg!(target_arch = "aarch64") && use_bti(&self.isa.isa_flags()) {
337 BranchProtection::BTI
338 } else {
339 BranchProtection::None
340 };
341 self.memory.finalize(branch_protection)?;
342
343 Ok(())
344 }
345
346 pub fn new(builder: JITBuilder) -> Self {
348 assert!(
349 !builder.isa.flags().is_pic(),
350 "cranelift-jit needs is_pic=false"
351 );
352
353 let memory = builder
354 .memory
355 .unwrap_or_else(|| Box::new(SystemMemoryProvider::new()));
356 Self {
357 isa: builder.isa,
358 symbols: RefCell::new(builder.symbols),
359 lookup_symbols: builder.lookup_symbols,
360 libcall_names: builder.libcall_names,
361 memory,
362 declarations: ModuleDeclarations::default(),
363 compiled_functions: SecondaryMap::new(),
364 compiled_data_objects: SecondaryMap::new(),
365 code_ranges: Vec::new(),
366 functions_to_finalize: Vec::new(),
367 data_objects_to_finalize: Vec::new(),
368 }
369 }
370
371 #[cfg(feature = "wasmtime-unwinder")]
375 pub fn lookup_wasmtime_exception_data<'a>(
376 &'a self,
377 pc: usize,
378 ) -> Option<(usize, wasmtime_unwinder::ExceptionTable<'a>)> {
379 let idx = match self
381 .code_ranges
382 .binary_search_by_key(&pc, |(start, _end, _func)| *start)
383 {
384 Ok(exact_start_match) => Some(exact_start_match),
385 Err(least_upper_bound) if least_upper_bound > 0 => {
386 let last_range_before_pc = &self.code_ranges[least_upper_bound - 1];
387 if last_range_before_pc.0 <= pc && pc < last_range_before_pc.1 {
388 Some(least_upper_bound - 1)
389 } else {
390 None
391 }
392 }
393 _ => None,
394 }?;
395
396 let (start, _, func) = self.code_ranges[idx];
397
398 let data = self.compiled_functions[func]
402 .as_ref()
403 .unwrap()
404 .exception_data
405 .as_ref()?;
406 let exception_table = wasmtime_unwinder::ExceptionTable::parse(data).ok()?;
407 Some((start, exception_table))
408 }
409}
410
411impl Module for JITModule {
    /// Returns the target ISA this module compiles for.
    fn isa(&self) -> &dyn TargetIsa {
        &*self.isa
    }
415
    /// Returns the declarations of all functions and data objects in this
    /// module.
    fn declarations(&self) -> &ModuleDeclarations {
        &self.declarations
    }
419
420 fn declare_function(
421 &mut self,
422 name: &str,
423 linkage: Linkage,
424 signature: &ir::Signature,
425 ) -> ModuleResult<FuncId> {
426 let (id, _linkage) = self
427 .declarations
428 .declare_function(name, linkage, signature)?;
429 Ok(id)
430 }
431
432 fn declare_anonymous_function(&mut self, signature: &ir::Signature) -> ModuleResult<FuncId> {
433 let id = self.declarations.declare_anonymous_function(signature)?;
434 Ok(id)
435 }
436
437 fn declare_data(
438 &mut self,
439 name: &str,
440 linkage: Linkage,
441 writable: bool,
442 tls: bool,
443 ) -> ModuleResult<DataId> {
444 assert!(!tls, "JIT doesn't yet support TLS");
445 let (id, _linkage) = self
446 .declarations
447 .declare_data(name, linkage, writable, tls)?;
448 Ok(id)
449 }
450
451 fn declare_anonymous_data(&mut self, writable: bool, tls: bool) -> ModuleResult<DataId> {
452 assert!(!tls, "JIT doesn't yet support TLS");
453 let id = self.declarations.declare_anonymous_data(writable, tls)?;
454 Ok(id)
455 }
456
457 fn define_function_with_control_plane(
458 &mut self,
459 id: FuncId,
460 ctx: &mut cranelift_codegen::Context,
461 ctrl_plane: &mut ControlPlane,
462 ) -> ModuleResult<()> {
463 info!("defining function {}: {}", id, ctx.func.display());
464 let decl = self.declarations.get_function_decl(id);
465 if !decl.linkage.is_definable() {
466 return Err(ModuleError::InvalidImportDefinition(
467 decl.linkage_name(id).into_owned(),
468 ));
469 }
470
471 if !self.compiled_functions[id].is_none() {
472 return Err(ModuleError::DuplicateDefinition(
473 decl.linkage_name(id).into_owned(),
474 ));
475 }
476
477 let res = ctx.compile(self.isa(), ctrl_plane)?;
479 let alignment = res.buffer.alignment as u64;
480 let compiled_code = ctx.compiled_code().unwrap();
481
482 let size = compiled_code.code_info().total_size as usize;
483 let align = alignment
484 .max(self.isa.function_alignment().minimum as u64)
485 .max(self.isa.symbol_alignment());
486 let ptr =
487 self.memory
488 .allocate_readexec(size, align)
489 .map_err(|e| ModuleError::Allocation {
490 message: "unable to alloc function",
491 err: e,
492 })?;
493
494 {
495 let mem = unsafe { std::slice::from_raw_parts_mut(ptr, size) };
496 mem.copy_from_slice(compiled_code.code_buffer());
497 }
498
499 let relocs = compiled_code
500 .buffer
501 .relocs()
502 .iter()
503 .map(|reloc| ModuleReloc::from_mach_reloc(reloc, &ctx.func, id))
504 .collect();
505
506 self.record_function_for_perf(ptr, size, &decl.linkage_name(id));
507 self.compiled_functions[id] = Some(CompiledBlob {
508 ptr,
509 size,
510 relocs,
511 #[cfg(feature = "wasmtime-unwinder")]
512 exception_data: None,
513 });
514
515 let range_start = ptr as usize;
516 let range_end = range_start + size;
517 self.code_ranges.push((range_start, range_end, id));
519
520 #[cfg(feature = "wasmtime-unwinder")]
521 {
522 let mut exception_builder = wasmtime_unwinder::ExceptionTableBuilder::default();
523 exception_builder
524 .add_func(0, compiled_code.buffer.call_sites())
525 .map_err(|_| {
526 ModuleError::Compilation(cranelift_codegen::CodegenError::Unsupported(
527 "Invalid exception data".into(),
528 ))
529 })?;
530 self.compiled_functions[id].as_mut().unwrap().exception_data =
531 Some(exception_builder.to_vec());
532 }
533
534 self.functions_to_finalize.push(id);
535
536 Ok(())
537 }
538
539 fn define_function_bytes(
540 &mut self,
541 id: FuncId,
542 alignment: u64,
543 bytes: &[u8],
544 relocs: &[ModuleReloc],
545 ) -> ModuleResult<()> {
546 info!("defining function {} with bytes", id);
547 let decl = self.declarations.get_function_decl(id);
548 if !decl.linkage.is_definable() {
549 return Err(ModuleError::InvalidImportDefinition(
550 decl.linkage_name(id).into_owned(),
551 ));
552 }
553
554 if !self.compiled_functions[id].is_none() {
555 return Err(ModuleError::DuplicateDefinition(
556 decl.linkage_name(id).into_owned(),
557 ));
558 }
559
560 let size = bytes.len();
561 let align = alignment
562 .max(self.isa.function_alignment().minimum as u64)
563 .max(self.isa.symbol_alignment());
564 let ptr =
565 self.memory
566 .allocate_readexec(size, align)
567 .map_err(|e| ModuleError::Allocation {
568 message: "unable to alloc function bytes",
569 err: e,
570 })?;
571
572 unsafe {
573 ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, size);
574 }
575
576 self.record_function_for_perf(ptr, size, &decl.linkage_name(id));
577 self.compiled_functions[id] = Some(CompiledBlob {
578 ptr,
579 size,
580 relocs: relocs.to_owned(),
581 #[cfg(feature = "wasmtime-unwinder")]
582 exception_data: None,
583 });
584
585 self.functions_to_finalize.push(id);
586
587 Ok(())
588 }
589
590 fn define_data(&mut self, id: DataId, data: &DataDescription) -> ModuleResult<()> {
591 let decl = self.declarations.get_data_decl(id);
592 if !decl.linkage.is_definable() {
593 return Err(ModuleError::InvalidImportDefinition(
594 decl.linkage_name(id).into_owned(),
595 ));
596 }
597
598 if !self.compiled_data_objects[id].is_none() {
599 return Err(ModuleError::DuplicateDefinition(
600 decl.linkage_name(id).into_owned(),
601 ));
602 }
603
604 assert!(!decl.tls, "JIT doesn't yet support TLS");
605
606 let &DataDescription {
607 ref init,
608 function_decls: _,
609 data_decls: _,
610 function_relocs: _,
611 data_relocs: _,
612 custom_segment_section: _,
613 align,
614 } = data;
615
616 let alloc_size = std::cmp::max(init.size(), 1);
621
622 let ptr = if decl.writable {
623 self.memory
624 .allocate_readwrite(alloc_size, align.unwrap_or(WRITABLE_DATA_ALIGNMENT))
625 .map_err(|e| ModuleError::Allocation {
626 message: "unable to alloc writable data",
627 err: e,
628 })?
629 } else {
630 self.memory
631 .allocate_readonly(alloc_size, align.unwrap_or(READONLY_DATA_ALIGNMENT))
632 .map_err(|e| ModuleError::Allocation {
633 message: "unable to alloc readonly data",
634 err: e,
635 })?
636 };
637
638 if ptr.is_null() {
639 std::alloc::handle_alloc_error(
641 std::alloc::Layout::from_size_align(
642 alloc_size,
643 align.unwrap_or(READONLY_DATA_ALIGNMENT).try_into().unwrap(),
644 )
645 .unwrap(),
646 );
647 }
648
649 match *init {
650 Init::Uninitialized => {
651 panic!("data is not initialized yet");
652 }
653 Init::Zeros { size } => {
654 unsafe { ptr::write_bytes(ptr, 0, size) };
655 }
656 Init::Bytes { ref contents } => {
657 let src = contents.as_ptr();
658 unsafe { ptr::copy_nonoverlapping(src, ptr, contents.len()) };
659 }
660 }
661
662 let pointer_reloc = match self.isa.triple().pointer_width().unwrap() {
663 PointerWidth::U16 => panic!(),
664 PointerWidth::U32 => Reloc::Abs4,
665 PointerWidth::U64 => Reloc::Abs8,
666 };
667 let relocs = data.all_relocs(pointer_reloc).collect::<Vec<_>>();
668
669 self.compiled_data_objects[id] = Some(CompiledBlob {
670 ptr,
671 size: init.size(),
672 relocs,
673 #[cfg(feature = "wasmtime-unwinder")]
674 exception_data: None,
675 });
676 self.data_objects_to_finalize.push(id);
677
678 Ok(())
679 }
680
    /// Looks up `name` among this module's declarations and returns the id of
    /// the function or data object declared under it, if any.
    fn get_name(&self, name: &str) -> Option<cranelift_module::FuncOrDataId> {
        self.declarations().get_name(name)
    }
684
    /// Returns the frontend configuration of this module's ISA.
    fn target_config(&self) -> cranelift_codegen::isa::TargetFrontendConfig {
        self.isa().frontend_config()
    }
688
    /// Creates a fresh compilation context whose function signature uses the
    /// ISA's default calling convention.
    fn make_context(&self) -> cranelift_codegen::Context {
        let mut ctx = cranelift_codegen::Context::new();
        ctx.func.signature.call_conv = self.isa().default_call_conv();
        ctx
    }
694
    /// Clears `ctx` and then resets its function signature to the ISA's
    /// default calling convention.
    fn clear_context(&self, ctx: &mut cranelift_codegen::Context) {
        ctx.clear();
        ctx.func.signature.call_conv = self.isa().default_call_conv();
    }
699
    /// Creates a new signature that uses the ISA's default calling
    /// convention.
    fn make_signature(&self) -> ir::Signature {
        ir::Signature::new(self.isa().default_call_conv())
    }
703
    /// Clears `sig`, passing the ISA's default calling convention to
    /// `Signature::clear`.
    fn clear_signature(&self, sig: &mut ir::Signature) {
        sig.clear(self.isa().default_call_conv());
    }
707}
708
709#[cfg(not(windows))]
710fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
711 let c_str = CString::new(name).unwrap();
712 let c_str_ptr = c_str.as_ptr();
713 let sym = unsafe { libc::dlsym(libc::RTLD_DEFAULT, c_str_ptr) };
714 if sym.is_null() {
715 None
716 } else {
717 Some(sym as *const u8)
718 }
719}
720
/// Default symbol lookup on Windows: resolves `name` with `GetProcAddress`,
/// trying a null module handle first and then `ucrtbase.dll`.
///
/// Returns `None` when neither lookup succeeds, and also when `name` contains
/// an interior NUL byte: such a name cannot be passed to `GetProcAddress`, so
/// it is reported as "not found" instead of panicking.
#[cfg(windows)]
fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
    use std::os::windows::io::RawHandle;
    use windows_sys::Win32::Foundation::HMODULE;
    use windows_sys::Win32::System::LibraryLoader;

    const UCRTBASE: &[u8] = b"ucrtbase.dll\0";

    let c_str = CString::new(name).ok()?;
    let c_str_ptr = c_str.as_ptr();

    // SAFETY: both strings handed to the Win32 calls are NUL-terminated and
    // outlive the calls.
    unsafe {
        let handles = [
            // NOTE(review): a null handle presumably targets the running
            // executable — confirm against the `GetProcAddress` documentation.
            ptr::null_mut(),
            // The C runtime module.
            LibraryLoader::GetModuleHandleA(UCRTBASE.as_ptr()) as RawHandle,
        ];

        for handle in &handles {
            let addr = LibraryLoader::GetProcAddress(*handle as HMODULE, c_str_ptr.cast());
            match addr {
                None => continue,
                Some(addr) => return Some(addr as *const u8),
            }
        }

        None
    }
}
751
752fn use_bti(isa_flags: &Vec<settings::Value>) -> bool {
753 isa_flags
754 .iter()
755 .find(|&f| f.name == "use_bti")
756 .map_or(false, |f| f.as_bool().unwrap_or(false))
757}