@@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>(
616616 Ok ( ( ) )
617617}
618618
619+ /// Conditionally multiplies the packed floating-point elements in
620+ /// `left` and `right` using the high 4 bits in `imm`, sums the calculated
621+ /// products (up to 4), and conditionally stores the sum in `dest` using
622+ /// the low 4 bits of `imm`.
623+ fn conditional_dot_product < ' tcx > (
624+ this : & mut crate :: MiriInterpCx < ' _ , ' tcx > ,
625+ left : & OpTy < ' tcx , Provenance > ,
626+ right : & OpTy < ' tcx , Provenance > ,
627+ imm : & OpTy < ' tcx , Provenance > ,
628+ dest : & PlaceTy < ' tcx , Provenance > ,
629+ ) -> InterpResult < ' tcx , ( ) > {
630+ let ( left, left_len) = this. operand_to_simd ( left) ?;
631+ let ( right, right_len) = this. operand_to_simd ( right) ?;
632+ let ( dest, dest_len) = this. place_to_simd ( dest) ?;
633+
634+ assert_eq ! ( left_len, right_len) ;
635+ assert ! ( dest_len <= 4 ) ;
636+
637+ let imm = this. read_scalar ( imm) ?. to_u8 ( ) ?;
638+
639+ let element_layout = left. layout . field ( this, 0 ) ;
640+
641+ // Calculate dot product
642+ // Elements are floating point numbers, but we can use `from_int`
643+ // because the representation of 0.0 is all zero bits.
644+ let mut sum = ImmTy :: from_int ( 0u8 , element_layout) ;
645+ for i in 0 ..left_len {
646+ if imm & ( 1 << i. checked_add ( 4 ) . unwrap ( ) ) != 0 {
647+ let left = this. read_immediate ( & this. project_index ( & left, i) ?) ?;
648+ let right = this. read_immediate ( & this. project_index ( & right, i) ?) ?;
649+
650+ let mul = this. wrapping_binary_op ( mir:: BinOp :: Mul , & left, & right) ?;
651+ sum = this. wrapping_binary_op ( mir:: BinOp :: Add , & sum, & mul) ?;
652+ }
653+ }
654+
655+ // Write to destination (conditioned to imm)
656+ for i in 0 ..dest_len {
657+ let dest = this. project_index ( & dest, i) ?;
658+
659+ if imm & ( 1 << i) != 0 {
660+ this. write_immediate ( * sum, & dest) ?;
661+ } else {
662+ this. write_scalar ( Scalar :: from_int ( 0u8 , element_layout. size ) , & dest) ?;
663+ }
664+ }
665+
666+ Ok ( ( ) )
667+ }
668+
619669/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
620670fn bin_op_folded < ' tcx , T > (
621671 this : & crate :: MiriInterpCx < ' _ , ' tcx > ,
0 commit comments