Skip to content

Commit c4d6b0b

Browse files
committed
Implement simd_fma and simd_relaxed_fma in const-eval
1 parent 58b4453 commit c4d6b0b

File tree

5 files changed

+88
-90
lines changed

5 files changed

+88
-90
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ use super::{
2525
};
2626
use crate::fluent_generated as fluent;
2727

28+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
29+
enum MulAddType {
30+
/// Used with `fma` and `simd_fma`, always uses fused-multiply-add
31+
Fused,
32+
/// Used with `fmuladd` and `simd_relaxed_fma`, nondeterministically determines whether to use
33+
/// fma or simple multiply-add
34+
Nondeterministic,
35+
}
36+
2837
/// Directly returns an `Allocation` containing an absolute path representation of the given type.
2938
pub(crate) fn alloc_type_name<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> (AllocId, u64) {
3039
let path = crate::util::type_name(tcx, ty);
@@ -630,14 +639,22 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
630639
dest,
631640
rustc_apfloat::Round::NearestTiesToEven,
632641
)?,
633-
sym::fmaf16 => self.fma_intrinsic::<Half>(args, dest)?,
634-
sym::fmaf32 => self.fma_intrinsic::<Single>(args, dest)?,
635-
sym::fmaf64 => self.fma_intrinsic::<Double>(args, dest)?,
636-
sym::fmaf128 => self.fma_intrinsic::<Quad>(args, dest)?,
637-
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest)?,
638-
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest)?,
639-
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest)?,
640-
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest)?,
642+
sym::fmaf16 => self.float_muladd_intrinsic::<Half>(args, dest, MulAddType::Fused)?,
643+
sym::fmaf32 => self.float_muladd_intrinsic::<Single>(args, dest, MulAddType::Fused)?,
644+
sym::fmaf64 => self.float_muladd_intrinsic::<Double>(args, dest, MulAddType::Fused)?,
645+
sym::fmaf128 => self.float_muladd_intrinsic::<Quad>(args, dest, MulAddType::Fused)?,
646+
sym::fmuladdf16 => {
647+
self.float_muladd_intrinsic::<Half>(args, dest, MulAddType::Nondeterministic)?
648+
}
649+
sym::fmuladdf32 => {
650+
self.float_muladd_intrinsic::<Single>(args, dest, MulAddType::Nondeterministic)?
651+
}
652+
sym::fmuladdf64 => {
653+
self.float_muladd_intrinsic::<Double>(args, dest, MulAddType::Nondeterministic)?
654+
}
655+
sym::fmuladdf128 => {
656+
self.float_muladd_intrinsic::<Quad>(args, dest, MulAddType::Nondeterministic)?
657+
}
641658

642659
// Unsupported intrinsic: skip the return_to_block below.
643660
_ => return interp_ok(false),
@@ -1038,40 +1055,41 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10381055
interp_ok(())
10391056
}
10401057

1041-
fn fma_intrinsic<F>(
1042-
&mut self,
1043-
args: &[OpTy<'tcx, M::Provenance>],
1044-
dest: &PlaceTy<'tcx, M::Provenance>,
1045-
) -> InterpResult<'tcx, ()>
1058+
fn float_muladd<F>(
1059+
&self,
1060+
a: Scalar<M::Provenance>,
1061+
b: Scalar<M::Provenance>,
1062+
c: Scalar<M::Provenance>,
1063+
typ: MulAddType,
1064+
) -> InterpResult<'tcx, Scalar<M::Provenance>>
10461065
where
10471066
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10481067
{
1049-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1050-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1051-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1068+
let a: F = a.to_float()?;
1069+
let b: F = b.to_float()?;
1070+
let c: F = c.to_float()?;
1071+
1072+
let fuse = typ == MulAddType::Fused || M::float_fuse_mul_add(self);
10521073

1053-
let res = a.mul_add(b, c).value;
1074+
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
10541075
let res = self.adjust_nan(res, &[a, b, c]);
1055-
self.write_scalar(res, dest)?;
1056-
interp_ok(())
1076+
interp_ok(res.into())
10571077
}
10581078

10591079
fn float_muladd_intrinsic<F>(
10601080
&mut self,
10611081
args: &[OpTy<'tcx, M::Provenance>],
10621082
dest: &PlaceTy<'tcx, M::Provenance>,
1083+
typ: MulAddType,
10631084
) -> InterpResult<'tcx, ()>
10641085
where
10651086
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10661087
{
1067-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1068-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1069-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1070-
1071-
let fuse = M::float_fuse_mul_add(self);
1088+
let a = self.read_scalar(&args[0])?;
1089+
let b = self.read_scalar(&args[1])?;
1090+
let c = self.read_scalar(&args[2])?;
10721091

1073-
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
1074-
let res = self.adjust_nan(res, &[a, b, c]);
1092+
let res = self.float_muladd::<F>(a, b, c, typ)?;
10751093
self.write_scalar(res, dest)?;
10761094
interp_ok(())
10771095
}

compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use either::Either;
22
use rustc_abi::Endian;
3+
use rustc_apfloat::ieee::{Double, Single};
34
use rustc_apfloat::{Float, Round};
45
use rustc_middle::mir::interpret::{InterpErrorKind, UndefinedBehaviorInfo};
56
use rustc_middle::ty::FloatTy;
@@ -8,8 +9,8 @@ use rustc_span::{Symbol, sym};
89
use tracing::trace;
910

1011
use super::{
11-
ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Provenance, Scalar, Size, interp_ok,
12-
throw_ub_format,
12+
ImmTy, InterpCx, InterpResult, Machine, MulAddType, OpTy, PlaceTy, Provenance, Scalar, Size,
13+
interp_ok, throw_ub_format,
1314
};
1415
use crate::interpret::Writeable;
1516

@@ -701,6 +702,43 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
701702
};
702703
}
703704
}
705+
sym::simd_fma | sym::simd_relaxed_fma => {
706+
// `simd_fma` must always deterministically use `mul_add`, whereas `simd_relaxed_fma`
707+
// is non-deterministic and may use either `mul_add` or `a * b + c`
708+
let typ = match intrinsic_name {
709+
sym::simd_fma => MulAddType::Fused,
710+
sym::simd_relaxed_fma => MulAddType::Nondeterministic,
711+
_ => unreachable!(),
712+
};
713+
714+
let (a, a_len) = self.project_to_simd(&args[0])?;
715+
let (b, b_len) = self.project_to_simd(&args[1])?;
716+
let (c, c_len) = self.project_to_simd(&args[2])?;
717+
let (dest, dest_len) = self.project_to_simd(&dest)?;
718+
719+
assert_eq!(dest_len, a_len);
720+
assert_eq!(dest_len, b_len);
721+
assert_eq!(dest_len, c_len);
722+
723+
for i in 0..dest_len {
724+
let a = self.read_scalar(&self.project_index(&a, i)?)?;
725+
let b = self.read_scalar(&self.project_index(&b, i)?)?;
726+
let c = self.read_scalar(&self.project_index(&c, i)?)?;
727+
let dest = self.project_index(&dest, i)?;
728+
729+
let ty::Float(float_ty) = dest.layout.ty.kind() else {
730+
span_bug!(self.cur_span(), "{} operand is not a float", intrinsic_name)
731+
};
732+
733+
let val = match float_ty {
734+
FloatTy::F16 => unimplemented!("f16_f128"),
735+
FloatTy::F32 => self.float_muladd::<Single>(a, b, c, typ)?,
736+
FloatTy::F64 => self.float_muladd::<Double>(a, b, c, typ)?,
737+
FloatTy::F128 => unimplemented!("f16_f128"),
738+
};
739+
self.write_scalar(val, &dest)?;
740+
}
741+
}
704742

705743
// Unsupported intrinsic: skip the return_to_block below.
706744
_ => return interp_ok(false),

compiler/rustc_const_eval/src/interpret/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ pub trait Machine<'tcx>: Sized {
290290
}
291291

292292
/// Determines whether the `fmuladd` intrinsics fuse the multiply-add or use separate operations.
293-
fn float_fuse_mul_add(_ecx: &mut InterpCx<'tcx, Self>) -> bool;
293+
fn float_fuse_mul_add(_ecx: &InterpCx<'tcx, Self>) -> bool;
294294

295295
/// Called before a basic block terminator is executed.
296296
#[inline]
@@ -676,7 +676,7 @@ pub macro compile_time_machine(<$tcx: lifetime>) {
676676
}
677677

678678
#[inline(always)]
679-
fn float_fuse_mul_add(_ecx: &mut InterpCx<$tcx, Self>) -> bool {
679+
fn float_fuse_mul_add(_ecx: &InterpCx<$tcx, Self>) -> bool {
680680
true
681681
}
682682

src/tools/miri/src/intrinsics/simd.rs

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use rand::Rng;
2-
use rustc_apfloat::Float;
31
use rustc_middle::ty;
42
use rustc_middle::ty::FloatTy;
53

@@ -83,62 +81,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
8381
this.write_scalar(val, &dest)?;
8482
}
8583
}
86-
"fma" | "relaxed_fma" => {
87-
let [a, b, c] = check_intrinsic_arg_count(args)?;
88-
let (a, a_len) = this.project_to_simd(a)?;
89-
let (b, b_len) = this.project_to_simd(b)?;
90-
let (c, c_len) = this.project_to_simd(c)?;
91-
let (dest, dest_len) = this.project_to_simd(dest)?;
92-
93-
assert_eq!(dest_len, a_len);
94-
assert_eq!(dest_len, b_len);
95-
assert_eq!(dest_len, c_len);
96-
97-
for i in 0..dest_len {
98-
let a = this.read_scalar(&this.project_index(&a, i)?)?;
99-
let b = this.read_scalar(&this.project_index(&b, i)?)?;
100-
let c = this.read_scalar(&this.project_index(&c, i)?)?;
101-
let dest = this.project_index(&dest, i)?;
102-
103-
let fuse: bool = intrinsic_name == "fma"
104-
|| (this.machine.float_nondet && this.machine.rng.get_mut().random());
105-
106-
// Works for f32 and f64.
107-
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
108-
let ty::Float(float_ty) = dest.layout.ty.kind() else {
109-
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
110-
};
111-
let val = match float_ty {
112-
FloatTy::F16 => unimplemented!("f16_f128"),
113-
FloatTy::F32 => {
114-
let a = a.to_f32()?;
115-
let b = b.to_f32()?;
116-
let c = c.to_f32()?;
117-
let res = if fuse {
118-
a.mul_add(b, c).value
119-
} else {
120-
((a * b).value + c).value
121-
};
122-
let res = this.adjust_nan(res, &[a, b, c]);
123-
Scalar::from(res)
124-
}
125-
FloatTy::F64 => {
126-
let a = a.to_f64()?;
127-
let b = b.to_f64()?;
128-
let c = c.to_f64()?;
129-
let res = if fuse {
130-
a.mul_add(b, c).value
131-
} else {
132-
((a * b).value + c).value
133-
};
134-
let res = this.adjust_nan(res, &[a, b, c]);
135-
Scalar::from(res)
136-
}
137-
FloatTy::F128 => unimplemented!("f16_f128"),
138-
};
139-
this.write_scalar(val, &dest)?;
140-
}
141-
}
14284
"expose_provenance" => {
14385
let [op] = check_intrinsic_arg_count(args)?;
14486
let (op, op_len) = this.project_to_simd(op)?;

src/tools/miri/src/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,8 +1324,8 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
13241324
}
13251325

13261326
#[inline(always)]
1327-
fn float_fuse_mul_add(ecx: &mut InterpCx<'tcx, Self>) -> bool {
1328-
ecx.machine.float_nondet && ecx.machine.rng.get_mut().random()
1327+
fn float_fuse_mul_add(ecx: &InterpCx<'tcx, Self>) -> bool {
1328+
ecx.machine.float_nondet && ecx.machine.rng.borrow_mut().random()
13291329
}
13301330

13311331
#[inline(always)]

0 commit comments

Comments
 (0)