1use anyhow::{Context as _, Result, anyhow};
3use core::mem;
4use cranelift::prelude::Imm64;
5use cranelift_codegen::cursor::{Cursor, FuncCursor};
6use cranelift_codegen::data_value::DataValue;
7use cranelift_codegen::ir::{
8 ExternalName, Function, InstBuilder, InstructionData, LibCall, Opcode, Signature,
9 UserExternalName, UserFuncName,
10};
11use cranelift_codegen::isa::{OwnedTargetIsa, TargetIsa};
12use cranelift_codegen::{CodegenError, Context, ir, settings};
13use cranelift_control::ControlPlane;
14use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
15use cranelift_jit::{JITBuilder, JITModule};
16use cranelift_module::{FuncId, Linkage, Module, ModuleError};
17use cranelift_native::builder_with_options;
18use cranelift_reader::TestFile;
19use pulley_interpreter::interp as pulley;
20use std::cell::Cell;
21use std::cmp::max;
22use std::collections::hash_map::Entry;
23use std::collections::{HashMap, HashSet};
24use std::ptr::NonNull;
25use target_lexicon::Architecture;
26use thiserror::Error;
27
/// Namespace used for every `UserExternalName` / `UserFuncName` this file assigns
/// (the post-rename names of declared functions and the generated trampoline names).
// NOTE(review): cranelift-module appears to treat namespace 0 as "index is a FuncId";
// confirm before changing this value.
const TESTFILE_NAMESPACE: u32 = 0;
29
/// Bookkeeping for one function (test-file function or trampoline) declared in the JIT module.
#[derive(Debug)]
struct DefinedFunction {
    /// Name the function is given before it is compiled. References to this function from
    /// other test-file functions are rewritten to this name as well
    /// (see `TestFileCompiler::apply_func_rename`).
    new_name: UserExternalName,

    /// The function's signature; used to marshal arguments and returns when calling it.
    signature: ir::Signature,

    /// Identifier of the function's declaration in the JIT module.
    func_id: FuncId,
}
49
/// Compiles Cranelift functions (typically from a test file) into a [`JITModule`] so they
/// can be called from the host.
///
/// Functions are declared first, then defined together with a per-signature trampoline,
/// and finally `compile` finalizes everything into a [`CompiledTestFile`].
pub struct TestFileCompiler {
    /// The JIT module receiving all declarations and definitions.
    module: JITModule,
    /// Scratch compilation context, reused for every function and cleared after each
    /// successful definition.
    ctx: Context,

    /// Every declared function (including trampolines), keyed by its original name.
    defined_functions: HashMap<UserFuncName, DefinedFunction>,

    /// Name of the trampoline generated for each distinct signature; each name is also a
    /// key of `defined_functions`.
    trampolines: HashMap<Signature, UserFuncName>,
}
95
96impl TestFileCompiler {
97 pub fn new(isa: OwnedTargetIsa) -> Self {
101 let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
102 builder.symbol_lookup_fn(Box::new(lookup_libcall));
103
104 #[cfg(unix)]
109 {
110 unsafe extern "C" {
111 safe fn cosf(f: f32) -> f32;
112 }
113 let f = std::hint::black_box(1.2_f32);
114 assert_eq!(f.cos(), cosf(f));
115 }
116
117 let module = JITModule::new(builder);
118 let ctx = module.make_context();
119
120 Self {
121 module,
122 ctx,
123 defined_functions: HashMap::new(),
124 trampolines: HashMap::new(),
125 }
126 }
127
128 pub fn with_host_isa(flags: settings::Flags) -> Result<Self> {
130 let builder = builder_with_options(true)
131 .map_err(anyhow::Error::msg)
132 .context("Unable to build a TargetIsa for the current host")?;
133 let isa = builder.finish(flags)?;
134 Ok(Self::new(isa))
135 }
136
137 pub fn with_default_host_isa() -> Result<Self> {
140 let flags = settings::Flags::new(settings::builder());
141 Self::with_host_isa(flags)
142 }
143
144 pub fn add_functions(
147 &mut self,
148 functions: &[Function],
149 ctrl_planes: Vec<ControlPlane>,
150 ) -> Result<()> {
151 for func in functions {
153 self.declare_function(func)?;
154 }
155
156 let ctrl_planes = ctrl_planes
157 .into_iter()
158 .chain(std::iter::repeat(ControlPlane::default()));
159
160 for (func, ref mut ctrl_plane) in functions.iter().zip(ctrl_planes) {
162 self.define_function(func.clone(), ctrl_plane)?;
163 self.create_trampoline_for_function(func, ctrl_plane)?;
164 }
165
166 Ok(())
167 }
168
169 pub fn add_testfile(&mut self, testfile: &TestFile) -> Result<()> {
172 let functions = testfile
173 .functions
174 .iter()
175 .map(|(f, _)| f)
176 .cloned()
177 .collect::<Vec<_>>();
178
179 self.add_functions(&functions[..], Vec::new())?;
180 Ok(())
181 }
182
183 pub fn declare_function(&mut self, func: &Function) -> Result<()> {
185 let next_id = self.defined_functions.len() as u32;
186 match self.defined_functions.entry(func.name.clone()) {
187 Entry::Occupied(_) => {
188 anyhow::bail!("Duplicate function with name {} found!", &func.name)
189 }
190 Entry::Vacant(v) => {
191 let name = func.name.to_string();
192 let func_id =
193 self.module
194 .declare_function(&name, Linkage::Local, &func.signature)?;
195
196 v.insert(DefinedFunction {
197 new_name: UserExternalName::new(TESTFILE_NAMESPACE, next_id),
198 signature: func.signature.clone(),
199 func_id,
200 });
201 }
202 };
203
204 Ok(())
205 }
206
207 fn apply_func_rename(
212 &self,
213 mut func: Function,
214 defined_func: &DefinedFunction,
215 ) -> Result<Function> {
216 let func_original_name = func.name;
218 func.name = UserFuncName::User(defined_func.new_name.clone());
219
220 let mut redefines = Vec::with_capacity(func.dfg.ext_funcs.len());
223 for (ext_ref, ext_func) in &func.dfg.ext_funcs {
224 let old_name = match &ext_func.name {
225 ExternalName::TestCase(tc) => UserFuncName::Testcase(tc.clone()),
226 ExternalName::User(username) => {
227 UserFuncName::User(func.params.user_named_funcs()[*username].clone())
228 }
229 _ => continue,
231 };
232
233 let target_df = self.defined_functions.get(&old_name).ok_or(anyhow!(
234 "Undeclared function {} is referenced by {}!",
235 &old_name,
236 &func_original_name
237 ))?;
238
239 redefines.push((ext_ref, target_df.new_name.clone()));
240 }
241
242 for (ext_ref, new_name) in redefines.into_iter() {
244 let new_name_ref = func.params.ensure_user_func_name(new_name);
246
247 func.dfg.ext_funcs[ext_ref].name = ExternalName::User(new_name_ref);
249 }
250
251 Ok(func)
252 }
253
254 pub fn define_function(
256 &mut self,
257 mut func: Function,
258 ctrl_plane: &mut ControlPlane,
259 ) -> Result<()> {
260 Self::replace_hostcall_references(&mut func);
261
262 let defined_func = self
263 .defined_functions
264 .get(&func.name)
265 .ok_or(anyhow!("Undeclared function {} found!", &func.name))?;
266
267 self.ctx.func = self.apply_func_rename(func, defined_func)?;
268 self.module.define_function_with_control_plane(
269 defined_func.func_id,
270 &mut self.ctx,
271 ctrl_plane,
272 )?;
273 self.module.clear_context(&mut self.ctx);
274 Ok(())
275 }
276
277 fn replace_hostcall_references(func: &mut Function) {
278 let mut funcrefs_to_remove = HashSet::new();
284 let mut cursor = FuncCursor::new(func);
285 while let Some(_block) = cursor.next_block() {
286 while let Some(inst) = cursor.next_inst() {
287 match &cursor.func.dfg.insts[inst] {
288 InstructionData::FuncAddr {
289 opcode: Opcode::FuncAddr,
290 func_ref,
291 } => {
292 let ext_func = &cursor.func.dfg.ext_funcs[*func_ref];
293 let hostcall_addr = match &ext_func.name {
294 ExternalName::TestCase(tc) if tc.raw() == b"__cranelift_throw" => {
295 Some(__cranelift_throw as usize)
296 }
297 _ => None,
298 };
299
300 if let Some(addr) = hostcall_addr {
301 funcrefs_to_remove.insert(*func_ref);
302 cursor.func.dfg.insts[inst] = InstructionData::UnaryImm {
303 opcode: Opcode::Iconst,
304 imm: Imm64::new(addr as i64),
305 };
306 }
307 }
308 _ => {}
309 }
310 }
311 }
312
313 for to_remove in funcrefs_to_remove {
314 func.dfg.ext_funcs[to_remove].name = ExternalName::LibCall(LibCall::Probestack);
315 }
316 }
317
318 pub fn create_trampoline_for_function(
320 &mut self,
321 func: &Function,
322 ctrl_plane: &mut ControlPlane,
323 ) -> Result<()> {
324 if !self.defined_functions.contains_key(&func.name) {
325 anyhow::bail!("Undeclared function {} found!", &func.name);
326 }
327
328 if self.trampolines.contains_key(&func.signature) {
330 return Ok(());
331 }
332
333 let name = UserFuncName::user(TESTFILE_NAMESPACE, self.defined_functions.len() as u32);
335 let trampoline = make_trampoline(name.clone(), &func.signature, self.module.isa());
336
337 self.declare_function(&trampoline)?;
338 self.define_function(trampoline, ctrl_plane)?;
339
340 self.trampolines.insert(func.signature.clone(), name);
341
342 Ok(())
343 }
344
345 pub fn compile(mut self) -> Result<CompiledTestFile, CompilationError> {
347 self.module.finalize_definitions()?;
351
352 Ok(CompiledTestFile {
353 module: Some(self.module),
354 defined_functions: self.defined_functions,
355 trampolines: self.trampolines,
356 })
357 }
358}
359
/// A finalized set of compiled functions and trampolines, ready to be called.
pub struct CompiledTestFile {
    /// The JIT module holding the finalized code. Wrapped in `Option` so `drop` can take
    /// it out and free its memory.
    module: Option<JITModule>,

    /// Declared functions (including trampolines), keyed by their original name.
    defined_functions: HashMap<UserFuncName, DefinedFunction>,

    /// Name of the trampoline built for each distinct signature.
    trampolines: HashMap<Signature, UserFuncName>,
}
374
375impl CompiledTestFile {
376 pub fn get_trampoline(&self, func: &Function) -> Option<Trampoline<'_>> {
380 let defined_func = self.defined_functions.get(&func.name)?;
381 let trampoline_id = self
382 .trampolines
383 .get(&func.signature)
384 .and_then(|name| self.defined_functions.get(name))
385 .map(|df| df.func_id)?;
386 Some(Trampoline {
387 module: self.module.as_ref()?,
388 func_id: defined_func.func_id,
389 func_signature: &defined_func.signature,
390 trampoline_id,
391 })
392 }
393}
394
395impl Drop for CompiledTestFile {
396 fn drop(&mut self) {
397 unsafe { self.module.take().unwrap().free_memory() }
400 }
401}
402
std::thread_local! {
    /// The `CompiledTestFile` whose code is currently running on this thread, or null
    /// when none is.
    ///
    /// Set by `Trampoline::call` around the call into compiled code. Read by the native
    /// `__cranelift_throw` hostcall to reach the module's exception data.
    pub static COMPILED_TEST_FILE: Cell<*const CompiledTestFile> = Cell::new(std::ptr::null());
}
409
/// A handle for calling one compiled function through the trampoline for its signature.
pub struct Trampoline<'a> {
    /// The module that owns both the function and the trampoline code.
    module: &'a JITModule,
    /// The function to call.
    func_id: FuncId,
    /// The function's signature; used to marshal arguments and returns.
    func_signature: &'a Signature,
    /// The trampoline compiled for `func_signature`.
    trampoline_id: FuncId,
}
417
418impl<'a> Trampoline<'a> {
419 pub fn call(&self, compiled: &CompiledTestFile, arguments: &[DataValue]) -> Vec<DataValue> {
421 let mut values = UnboxedValues::make_arguments(arguments, &self.func_signature);
422 let arguments_address = values.as_mut_ptr();
423
424 let function_ptr = self.module.get_finalized_function(self.func_id);
425 let trampoline_ptr = self.module.get_finalized_function(self.trampoline_id);
426
427 COMPILED_TEST_FILE.set(compiled as *const _);
428 unsafe {
429 self.call_raw(trampoline_ptr, function_ptr, arguments_address);
430 }
431 COMPILED_TEST_FILE.set(std::ptr::null());
432
433 values.collect_returns(&self.func_signature)
434 }
435
436 unsafe fn call_raw(
437 &self,
438 trampoline_ptr: *const u8,
439 function_ptr: *const u8,
440 arguments_address: *mut u128,
441 ) {
442 match self.module.isa().triple().architecture {
443 Architecture::Pulley32
446 | Architecture::Pulley64
447 | Architecture::Pulley32be
448 | Architecture::Pulley64be => {
449 let mut state = pulley::Vm::new();
450 unsafe {
451 state.call(
452 NonNull::new(trampoline_ptr.cast_mut()).unwrap(),
453 &[
454 pulley::XRegVal::new_ptr(function_ptr.cast_mut()).into(),
455 pulley::XRegVal::new_ptr(arguments_address).into(),
456 ],
457 [],
458 );
459 }
460 }
461
462 _ => {
464 let callable_trampoline: fn(*const u8, *mut u128) -> () =
465 unsafe { mem::transmute(trampoline_ptr) };
466 callable_trampoline(function_ptr, arguments_address);
467 }
468 }
469 }
470}
471
/// Errors returned by `TestFileCompiler::compile`.
#[derive(Error, Debug)]
pub enum CompilationError {
    /// Cranelift failed to generate code.
    #[error("Cranelift codegen error")]
    CodegenError(#[from] CodegenError),
    /// The JIT module rejected an operation (e.g. while finalizing definitions).
    #[error("Module error")]
    ModuleError(#[from] ModuleError),
    /// An I/O failure; per its message, a memory-mapping error.
    // NOTE(review): no `?` visible in this file produces this variant — confirm it is still used.
    #[error("Memory mapping error")]
    IoError(#[from] std::io::Error),
}
485
/// Buffer of 16-byte slots through which a trampoline receives its arguments and writes
/// back its results (slot `i` holds argument `i`, and later result `i`).
struct UnboxedValues(Vec<u128>);
489
490impl UnboxedValues {
491 const SLOT_SIZE: usize = 16;
496
497 pub fn make_arguments(arguments: &[DataValue], signature: &ir::Signature) -> Self {
500 assert_eq!(arguments.len(), signature.params.len());
501 let mut values_vec = vec![0; max(signature.params.len(), signature.returns.len())];
502
503 for ((arg, slot), param) in arguments.iter().zip(&mut values_vec).zip(&signature.params) {
505 assert!(
506 arg.ty() == param.value_type || arg.is_vector(),
507 "argument type mismatch: {} != {}",
508 arg.ty(),
509 param.value_type
510 );
511 unsafe {
512 arg.write_value_to(slot);
513 }
514 }
515
516 Self(values_vec)
517 }
518
519 pub fn as_mut_ptr(&mut self) -> *mut u128 {
521 self.0.as_mut_ptr()
522 }
523
524 pub fn collect_returns(&self, signature: &ir::Signature) -> Vec<DataValue> {
527 assert!(self.0.len() >= signature.returns.len());
528 let mut returns = Vec::with_capacity(signature.returns.len());
529
530 for (slot, param) in self.0.iter().zip(&signature.returns) {
532 let value = unsafe { DataValue::read_value_from(slot, param.value_type) };
533 returns.push(value);
534 }
535
536 returns
537 }
538}
539
540fn make_trampoline(name: UserFuncName, signature: &ir::Signature, isa: &dyn TargetIsa) -> Function {
546 let pointer_type = isa.pointer_type();
548 let mut wrapper_sig = ir::Signature::new(isa.frontend_config().default_call_conv);
549 wrapper_sig.params.push(ir::AbiParam::new(pointer_type)); wrapper_sig.params.push(ir::AbiParam::new(pointer_type)); let mut func = ir::Function::with_name_signature(name, wrapper_sig);
553
554 let mut builder_context = FunctionBuilderContext::new();
556 let mut builder = FunctionBuilder::new(&mut func, &mut builder_context);
557 let block0 = builder.create_block();
558 builder.append_block_params_for_function_params(block0);
559 builder.switch_to_block(block0);
560 builder.seal_block(block0);
561
562 let (callee_value, values_vec_ptr_val) = {
564 let params = builder.func.dfg.block_params(block0);
565 (params[0], params[1])
566 };
567
568 let callee_args = signature
570 .params
571 .iter()
572 .enumerate()
573 .map(|(i, param)| {
574 let mut flags = ir::MemFlags::trusted();
576 if param.value_type.is_vector() {
577 flags.set_endianness(ir::Endianness::Little);
578 }
579
580 builder.ins().load(
582 param.value_type,
583 flags,
584 values_vec_ptr_val,
585 (i * UnboxedValues::SLOT_SIZE) as i32,
586 )
587 })
588 .collect::<Vec<_>>();
589
590 let new_sig = builder.import_signature(signature.clone());
592 let call = builder
593 .ins()
594 .call_indirect(new_sig, callee_value, &callee_args);
595
596 let results = builder.func.dfg.inst_results(call).to_vec();
598 for ((i, value), param) in results.iter().enumerate().zip(&signature.returns) {
599 let mut flags = ir::MemFlags::trusted();
601 if param.value_type.is_vector() {
602 flags.set_endianness(ir::Endianness::Little);
603 }
604 builder.ins().store(
606 flags,
607 *value,
608 values_vec_ptr_val,
609 (i * UnboxedValues::SLOT_SIZE) as i32,
610 );
611 }
612
613 builder.ins().return_(&[]);
614 builder.finalize();
615
616 func
617}
618
619#[cfg(any(
626 target_arch = "x86_64",
627 target_arch = "aarch64",
628 target_arch = "s390x",
629 target_arch = "riscv64"
630))]
631extern "C-unwind" fn __cranelift_throw(
632 entry_fp: usize,
633 exit_fp: usize,
634 exit_pc: usize,
635 tag: u32,
636 payload1: usize,
637 payload2: usize,
638) -> ! {
639 let compiled_test_file = unsafe { &*COMPILED_TEST_FILE.get() };
640 let unwind_host = wasmtime_unwinder::UnwindHost;
641 let frame_handler = |frame: &wasmtime_unwinder::Frame| -> Option<(usize, usize)> {
642 let (base, table) = compiled_test_file
643 .module
644 .as_ref()
645 .unwrap()
646 .lookup_wasmtime_exception_data(frame.pc())?;
647 let relative_pc = u32::try_from(
648 frame
649 .pc()
650 .checked_sub(base)
651 .expect("module lookup did not return a module base below the PC"),
652 )
653 .expect("module larger than 4GiB");
654
655 table
656 .lookup_pc_tag(relative_pc, tag)
657 .map(|(frame_offset, handler)| {
658 let handler_sp = frame
659 .fp()
660 .wrapping_sub(usize::try_from(frame_offset).unwrap());
661 let handler_pc = base
662 .checked_add(usize::try_from(handler).unwrap())
663 .expect("Handler address computation overflowed");
664 (handler_pc, handler_sp)
665 })
666 };
667 unsafe {
668 match wasmtime_unwinder::Handler::find(
669 &unwind_host,
670 frame_handler,
671 exit_pc,
672 exit_fp,
673 entry_fp,
674 ) {
675 Some(handler) => handler.resume_tailcc(payload1, payload2),
676 None => {
677 panic!("Expected a handler to exit for throw of tag {tag} at pc {exit_pc:x}");
678 }
679 }
680 }
681}
682
/// Fallback `%__cranelift_throw` hostcall for hosts without a native backend listed above.
///
/// Exceptions cannot be dispatched here, so every throw panics.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    target_arch = "s390x",
    target_arch = "riscv64"
)))]
extern "C-unwind" fn __cranelift_throw(
    _entry_fp: usize,
    _exit_fp: usize,
    _exit_pc: usize,
    _tag: u32,
    _payload1: usize,
    _payload2: usize,
) -> ! {
    panic!("Throw not implemented on platforms without native backends.");
}
699
700fn lookup_libcall(name: &str) -> Option<*const u8> {
708 match name {
709 "ceil" => {
710 extern "C" fn ceil(a: f64) -> f64 {
711 a.ceil()
712 }
713 Some(ceil as *const u8)
714 }
715 "ceilf" => {
716 extern "C" fn ceilf(a: f32) -> f32 {
717 a.ceil()
718 }
719 Some(ceilf as *const u8)
720 }
721 "trunc" => {
722 extern "C" fn trunc(a: f64) -> f64 {
723 a.trunc()
724 }
725 Some(trunc as *const u8)
726 }
727 "truncf" => {
728 extern "C" fn truncf(a: f32) -> f32 {
729 a.trunc()
730 }
731 Some(truncf as *const u8)
732 }
733 "floor" => {
734 extern "C" fn floor(a: f64) -> f64 {
735 a.floor()
736 }
737 Some(floor as *const u8)
738 }
739 "floorf" => {
740 extern "C" fn floorf(a: f32) -> f32 {
741 a.floor()
742 }
743 Some(floorf as *const u8)
744 }
745 "nearbyint" => {
746 extern "C" fn nearbyint(a: f64) -> f64 {
747 a.round_ties_even()
748 }
749 Some(nearbyint as *const u8)
750 }
751 "nearbyintf" => {
752 extern "C" fn nearbyintf(a: f32) -> f32 {
753 a.round_ties_even()
754 }
755 Some(nearbyintf as *const u8)
756 }
757 "fma" => {
758 extern "C" fn fma(a: f64, b: f64, c: f64) -> f64 {
762 #[cfg(all(target_os = "windows", target_env = "gnu"))]
763 return libm::fma(a, b, c);
764 #[cfg(not(all(target_os = "windows", target_env = "gnu")))]
765 return a.mul_add(b, c);
766 }
767 Some(fma as *const u8)
768 }
769 "fmaf" => {
770 extern "C" fn fmaf(a: f32, b: f32, c: f32) -> f32 {
771 #[cfg(all(target_os = "windows", target_env = "gnu"))]
772 return libm::fmaf(a, b, c);
773 #[cfg(not(all(target_os = "windows", target_env = "gnu")))]
774 return a.mul_add(b, c);
775 }
776 Some(fmaf as *const u8)
777 }
778
779 #[cfg(target_arch = "x86_64")]
780 "__cranelift_x86_pshufb" => Some(__cranelift_x86_pshufb as *const u8),
781
782 _ => panic!("unknown libcall {name}"),
783 }
784}
785
786#[cfg(target_arch = "x86_64")]
787use std::arch::x86_64::__m128i;
/// Software implementation of x86 `pshufb`, exposed as the `__cranelift_x86_pshufb` libcall.
///
/// For each byte lane `i`: if bit 7 of `b[i]` is set the result byte is 0; otherwise it is
/// `a[b[i] & 0xf]`.
#[cfg(target_arch = "x86_64")]
#[expect(
    improper_ctypes_definitions,
    reason = "manually verified to work for now"
)]
extern "C" fn __cranelift_x86_pshufb(a: __m128i, b: __m128i) -> __m128i {
    // Reinterprets a 128-bit vector as its 16 bytes and back.
    union Bytes {
        vector: __m128i,
        bytes: [u8; 16],
    }

    // SAFETY: both fields are 16 bytes of plain data with no invalid bit patterns.
    let (table, indices) = unsafe { (Bytes { vector: a }.bytes, Bytes { vector: b }.bytes) };

    let shuffled: [u8; 16] = std::array::from_fn(|lane| {
        let index = indices[lane];
        if index & 0x80 == 0 {
            table[usize::from(index & 0x0f)]
        } else {
            0x00
        }
    });

    // SAFETY: as above.
    unsafe { Bytes { bytes: shuffled }.vector }
}
834
#[cfg(test)]
mod test {
    use super::*;
    use cranelift_reader::{ParseOptions, parse_functions, parse_test};

    /// Parses `code` and returns its first function.
    fn parse(code: &str) -> Function {
        parse_functions(code).unwrap().into_iter().nth(0).unwrap()
    }

    /// End to end: declare, define, build a trampoline, compile, and call a function.
    #[test]
    fn nop() {
        // Skip this test when Cranelift has no backend for the host.
        if cranelift_native::builder().is_err() {
            return;
        }
        let code = String::from(
            "
            test run
            function %test() -> i8 {
            block0:
                nop
                v1 = iconst.i8 -1
                return v1
            }",
        );
        let ctrl_plane = &mut ControlPlane::default();

        // Extract the single function from the parsed test file.
        let test_file = parse_test(code.as_str(), ParseOptions::default()).unwrap();
        assert_eq!(1, test_file.functions.len());
        let function = test_file.functions[0].0.clone();

        // Compile it and call it through its trampoline.
        let mut compiler = TestFileCompiler::with_default_host_isa().unwrap();
        compiler.declare_function(&function).unwrap();
        compiler
            .define_function(function.clone(), ctrl_plane)
            .unwrap();
        compiler
            .create_trampoline_for_function(&function, ctrl_plane)
            .unwrap();
        let compiled = compiler.compile().unwrap();
        let trampoline = compiled.get_trampoline(&function).unwrap();
        let returned = trampoline.call(&compiled, &[]);
        assert_eq!(returned, vec![DataValue::I8(-1)])
    }

    /// Pins the IR emitted by `make_trampoline`: one 16-byte slot per value, vector
    /// slots accessed little-endian, and results stored back over the leading slots.
    #[test]
    fn trampolines() {
        // Skip this test when Cranelift has no backend for the host.
        if cranelift_native::builder().is_err() {
            return;
        }
        let function = parse(
            "
            function %test(f32, i8, i64x2, i8) -> f32x4, i64 {
            block0(v0: f32, v1: i8, v2: i64x2, v3: i8):
                v4 = vconst.f32x4 [0x0.1 0x0.2 0x0.3 0x0.4]
                v5 = iconst.i64 -1
                return v4, v5
            }",
        );

        let compiler = TestFileCompiler::with_default_host_isa().unwrap();
        let trampoline = make_trampoline(
            UserFuncName::user(0, 0),
            &function.signature,
            compiler.module.isa(),
        );
        println!("{trampoline}");
        assert!(format!("{trampoline}").ends_with(
            "sig0 = (f32, i8, i64x2, i8) -> f32x4, i64 fast

block0(v0: i64, v1: i64):
    v2 = load.f32 notrap aligned v1
    v3 = load.i8 notrap aligned v1+16
    v4 = load.i64x2 notrap aligned little v1+32
    v5 = load.i8 notrap aligned v1+48
    v6, v7 = call_indirect sig0, v0(v2, v3, v4, v5)
    store notrap aligned little v6, v1
    store notrap aligned v7, v1+16
    return
}
"
        ));
    }
}