Skip to main content

cranelift_codegen/
nan_canonicalization.rs

//! A NaN-canonicalizing rewriting pass. Patch floating-point arithmetic
//! instructions that may return a NaN result with a sequence of operations
//! that replaces nondeterministic NaNs with a single canonical NaN value.
4
5use crate::cursor::{Cursor, FuncCursor};
6use crate::ir::condcodes::FloatCC;
7use crate::ir::immediates::{Ieee16, Ieee32, Ieee64, Ieee128};
8use crate::ir::types::{self};
9use crate::ir::{Function, Inst, InstBuilder, InstructionData, Opcode, Value};
10use crate::opts::MemFlags;
11use crate::timing;
12
13/// Perform the NaN canonicalization pass.
14pub fn do_nan_canonicalization(func: &mut Function, has_vector_support: bool) {
15    let _tt = timing::canonicalize_nans();
16    let mut pos = FuncCursor::new(func);
17    while let Some(_block) = pos.next_block() {
18        while let Some(inst) = pos.next_inst() {
19            if is_fp_arith(&mut pos, inst) {
20                add_nan_canon_seq(&mut pos, inst, has_vector_support);
21            }
22        }
23    }
24}
25
26/// Returns true/false based on whether the instruction is a floating-point
27/// arithmetic operation. This ignores operations like `fneg`, `fabs`, or
28/// `fcopysign` that only operate on the sign bit of a floating point value.
29fn is_fp_arith(pos: &mut FuncCursor, inst: Inst) -> bool {
30    match pos.func.dfg.insts[inst] {
31        InstructionData::Unary { opcode, .. } => {
32            opcode == Opcode::Ceil
33                || opcode == Opcode::Floor
34                || opcode == Opcode::Nearest
35                || opcode == Opcode::Sqrt
36                || opcode == Opcode::Trunc
37                || opcode == Opcode::Fdemote
38                || opcode == Opcode::Fpromote
39                || opcode == Opcode::FvpromoteLow
40                || opcode == Opcode::Fvdemote
41        }
42        InstructionData::Binary { opcode, .. } => {
43            opcode == Opcode::Fadd
44                || opcode == Opcode::Fdiv
45                || opcode == Opcode::Fmax
46                || opcode == Opcode::Fmin
47                || opcode == Opcode::Fmul
48                || opcode == Opcode::Fsub
49        }
50        InstructionData::Ternary { opcode, .. } => opcode == Opcode::Fma,
51        _ => false,
52    }
53}
54
55/// Append a sequence of canonicalizing instructions after the given instruction.
56fn add_nan_canon_seq(pos: &mut FuncCursor, inst: Inst, has_vector_support: bool) {
57    // Select the instruction result, result type. Replace the instruction
58    // result and step forward before inserting the canonicalization sequence.
59    let val = pos.func.dfg.first_result(inst);
60    let val_type = pos.func.dfg.value_type(val);
61    let new_res = pos.func.dfg.replace_result(val, val_type);
62    let _next_inst = pos.next_inst().expect("block missing terminator!");
63
64    // Insert a comparison instruction, to check if `inst_res` is NaN (comparing
65    // against NaN is always unordered). Select the canonical NaN value if `val`
66    // is NaN, assign the result to `inst`.
67    let comparison = FloatCC::Unordered;
68
69    let vectorized_scalar_select = |pos: &mut FuncCursor, canon_nan: Value, ty: types::Type| {
70        let canon_nan = pos.ins().scalar_to_vector(ty, canon_nan);
71        let new_res = pos.ins().scalar_to_vector(ty, new_res);
72        let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
73        let is_nan = pos.ins().bitcast(ty, MemFlags::new(), is_nan);
74        let simd_result = pos.ins().bitselect(is_nan, canon_nan, new_res);
75        pos.ins().with_result(val).extractlane(simd_result, 0);
76    };
77    let scalar_select = |pos: &mut FuncCursor, canon_nan: Value| {
78        let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
79        pos.ins()
80            .with_result(val)
81            .select(is_nan, canon_nan, new_res);
82    };
83
84    let vector_select = |pos: &mut FuncCursor, canon_nan: Value| {
85        let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
86        let is_nan = pos.ins().bitcast(val_type, MemFlags::new(), is_nan);
87        pos.ins()
88            .with_result(val)
89            .bitselect(is_nan, canon_nan, new_res);
90    };
91
92    match val_type {
93        types::F16 => {
94            let canon_nan = pos.ins().f16const(Ieee16::NAN);
95            scalar_select(pos, canon_nan);
96        }
97        types::F32 => {
98            let canon_nan = pos.ins().f32const(Ieee32::NAN);
99            if has_vector_support {
100                vectorized_scalar_select(pos, canon_nan, types::F32X4);
101            } else {
102                scalar_select(pos, canon_nan);
103            }
104        }
105        types::F64 => {
106            let canon_nan = pos.ins().f64const(Ieee64::NAN);
107            if has_vector_support {
108                vectorized_scalar_select(pos, canon_nan, types::F64X2);
109            } else {
110                scalar_select(pos, canon_nan);
111            }
112        }
113        types::F32X4 => {
114            let canon_nan = pos.ins().f32const(Ieee32::NAN);
115            let canon_nan = pos.ins().splat(types::F32X4, canon_nan);
116            vector_select(pos, canon_nan);
117        }
118        types::F64X2 => {
119            let canon_nan = pos.ins().f64const(Ieee64::NAN);
120            let canon_nan = pos.ins().splat(types::F64X2, canon_nan);
121            vector_select(pos, canon_nan);
122        }
123        types::F128 => {
124            let nan_const = pos.func.dfg.constants.insert(Ieee128::NAN.into());
125            let canon_nan = pos.ins().f128const(nan_const);
126            scalar_select(pos, canon_nan);
127        }
128        _ => {
129            // Panic if the type given was not an IEEE floating point type.
130            panic!("Could not canonicalize NaN: Unexpected result type found.");
131        }
132    }
133
134    pos.prev_inst(); // Step backwards so the pass does not skip instructions.
135}