Skip to content

Commit b228256

Browse files
authored
[ARM] Introduce intrinsics for MVE fma under strict-fp. (#169771)
Similar to #169156, this adds an @arm.mve.fma intrinsic for strict-fp. A Builder class is added to act as the common subclass of IRBuilder and IRInt.
1 parent ce2c081 commit b228256

File tree

7 files changed

+839
-342
lines changed

7 files changed

+839
-342
lines changed

clang/include/clang/Basic/arm_mve.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ multiclass FMA<bit add> {
167167
// second multiply input.
168168
defvar m2_cg = !if(add, (id $m2), (fneg $m2));
169169

170-
defvar unpred_cg = (IRIntBase<"fma", [Vector]> $m1, m2_cg, $addend);
170+
defvar fma = strictFPAlt<IRIntBase<"fma", [Vector]>,
171+
IRInt<"fma", [Vector]>>;
172+
defvar unpred_cg = (fma $m1, m2_cg, $addend);
171173
defvar pred_cg = (IRInt<"fma_predicated", [Vector, Predicate]>
172174
$m1, m2_cg, $addend, $pred);
173175

@@ -723,7 +725,7 @@ multiclass compare_with_pred<string condname, dag arguments,
723725
NameOverride<"vcmp" # condname # "q_m" # suffix>;
724726
}
725727

726-
multiclass compare<string condname, IRBuilder cmpop> {
728+
multiclass compare<string condname, Builder cmpop> {
727729
// Make all four variants of a comparison: the vector/vector and
728730
// vector/scalar forms, each using compare_with_pred to make a
729731
// predicated and unpredicated version.

clang/include/clang/Basic/arm_mve_defs.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class IRBuilderAddrParam<int index_> : IRBuilderParam<index_>;
3434
class IRBuilderIntParam<int index_, string type_> : IRBuilderParam<index_> {
3535
string type = type_;
3636
}
37-
class IRBuilderBase {
37+
class Builder {}
38+
class IRBuilderBase : Builder {
3839
// The prefix of the function call, including an open parenthesis.
3940
string prefix;
4041

@@ -166,7 +167,7 @@ def address;
166167
// Another node class you can use in the codegen dag. This one corresponds to
167168
// an IR intrinsic function, which has to be specialized to a particular list
168169
// of types.
169-
class IRIntBase<string name_, list<Type> params_ = [], bit appendKind_ = 0> {
170+
class IRIntBase<string name_, list<Type> params_ = [], bit appendKind_ = 0> : Builder {
170171
string intname = name_; // base name of the intrinsic
171172
list<Type> params = params_; // list of parameter types
172173

@@ -214,8 +215,8 @@ def bitsize;
214215

215216
// strictFPAlt allows a node to have different code generation under strict-fp.
216217
// TODO: The standard node can be IRBuilderBase or IRIntBase.
217-
class strictFPAlt<IRBuilderBase standard_, IRIntBase strictfp_> {
218-
IRBuilderBase standard = standard_;
218+
class strictFPAlt<Builder standard_, IRIntBase strictfp_> : Builder {
219+
Builder standard = standard_;
219220
IRIntBase strictfp = strictfp_;
220221
}
221222

clang/test/CodeGen/arm-mve-intrinsics/ternary.c

Lines changed: 692 additions & 320 deletions
Large diffs are not rendered by default.

clang/utils/TableGen/MveEmitter.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,9 @@ Result::Ptr EmitterBase::getCodeForDag(const DagInit *D,
12601260
for (unsigned i = 0, e = D->getNumArgs(); i < e; ++i)
12611261
Args.push_back(getCodeForDagArg(D, i, Scope, Param));
12621262

1263-
auto GenIRBuilderBase = [&](const Record *Op) {
1263+
auto GenIRBuilderBase = [&](const Record *Op) -> Result::Ptr {
1264+
assert(Op->isSubClassOf("IRBuilderBase") &&
1265+
"Expected IRBuilderBase in GenIRBuilderBase\n");
12641266
std::set<unsigned> AddressArgs;
12651267
std::map<unsigned, std::string> IntegerArgs;
12661268
for (const Record *sp : Op->getValueAsListOfDefs("special_params")) {
@@ -1274,7 +1276,9 @@ Result::Ptr EmitterBase::getCodeForDag(const DagInit *D,
12741276
return std::make_shared<IRBuilderResult>(Op->getValueAsString("prefix"),
12751277
Args, AddressArgs, IntegerArgs);
12761278
};
1277-
auto GenIRIntBase = [&](const Record *Op) {
1279+
auto GenIRIntBase = [&](const Record *Op) -> Result::Ptr {
1280+
assert(Op->isSubClassOf("IRIntBase") &&
1281+
"Expected IRIntBase in GenIRIntBase\n");
12781282
std::vector<const Type *> ParamTypes;
12791283
for (const Record *RParam : Op->getValueAsListOfDefs("params"))
12801284
ParamTypes.push_back(getType(RParam, Param));
@@ -1289,8 +1293,11 @@ Result::Ptr EmitterBase::getCodeForDag(const DagInit *D,
12891293
} else if (Op->isSubClassOf("IRIntBase")) {
12901294
return GenIRIntBase(Op);
12911295
} else if (Op->isSubClassOf("strictFPAlt")) {
1292-
auto Standard = GenIRBuilderBase(Op->getValueAsDef("standard"));
1293-
auto StrictFp = GenIRIntBase(Op->getValueAsDef("strictfp"));
1296+
auto StardardBuilder = Op->getValueAsDef("standard");
1297+
Result::Ptr Standard = StardardBuilder->isSubClassOf("IRBuilder")
1298+
? GenIRBuilderBase(StardardBuilder)
1299+
: GenIRIntBase(StardardBuilder);
1300+
Result::Ptr StrictFp = GenIRIntBase(Op->getValueAsDef("strictfp"));
12941301
return std::make_shared<StrictFpAltResult>(Standard, StrictFp);
12951302
} else {
12961303
PrintFatalError("Unsupported dag node " + Op->getName());

llvm/include/llvm/IR/IntrinsicsARM.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,9 @@ def int_arm_mve_vqmovn_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
13621362
llvm_i32_ty /* unsigned output */, llvm_i32_ty /* unsigned input */,
13631363
llvm_i32_ty /* top half */, llvm_anyvector_ty /* pred */], [IntrNoMem]>;
13641364

1365+
def int_arm_mve_fma: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
1366+
[LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,
1367+
LLVMMatchType<0> /* addend */], [IntrNoMem]>;
13651368
// fma_predicated returns the add operand for disabled lanes.
13661369
def int_arm_mve_fma_predicated: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
13671370
[LLVMMatchType<0> /* mult op #1 */, LLVMMatchType<0> /* mult op #2 */,

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3723,6 +3723,10 @@ multiclass MVE_VFMA_fp_multi<string iname, bit fms, MVEVectorVTInfo VTI> {
37233723
if fms then {
37243724
def : Pat<(VTI.Vec (fma (fneg m1), m2, add)),
37253725
(Inst $add, $m1, $m2)>;
3726+
def : Pat<(VTI.Vec (int_arm_mve_fma (fneg m1), m2, add)),
3727+
(Inst $add, $m1, $m2)>;
3728+
def : Pat<(VTI.Vec (int_arm_mve_fma m1, (fneg m2), add)),
3729+
(Inst $add, $m1, $m2)>;
37263730
def : Pat<(VTI.Vec (vselect (VTI.Pred VCCR:$pred),
37273731
(VTI.Vec (fma (fneg m1), m2, add)),
37283732
add)),
@@ -3734,6 +3738,8 @@ multiclass MVE_VFMA_fp_multi<string iname, bit fms, MVEVectorVTInfo VTI> {
37343738
} else {
37353739
def : Pat<(VTI.Vec (fma m1, m2, add)),
37363740
(Inst $add, $m1, $m2)>;
3741+
def : Pat<(VTI.Vec (int_arm_mve_fma m1, m2, add)),
3742+
(Inst $add, $m1, $m2)>;
37373743
def : Pat<(VTI.Vec (vselect (VTI.Pred VCCR:$pred),
37383744
(VTI.Vec (fma m1, m2, add)),
37393745
add)),
@@ -5672,6 +5678,8 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
56725678
if scalar_addend then {
56735679
def : Pat<(VTI.Vec (fma v1, v2, vs)),
56745680
(VTI.Vec (Inst v1, v2, is))>;
5681+
def : Pat<(VTI.Vec (int_arm_mve_fma v1, v2, vs)),
5682+
(VTI.Vec (Inst v1, v2, is))>;
56755683
def : Pat<(VTI.Vec (vselect (VTI.Pred VCCR:$pred),
56765684
(VTI.Vec (fma v1, v2, vs)),
56775685
v1)),
@@ -5681,6 +5689,10 @@ multiclass MVE_VFMA_qr_multi<string iname, MVEVectorVTInfo VTI,
56815689
(VTI.Vec (Inst v2, v1, is))>;
56825690
def : Pat<(VTI.Vec (fma vs, v1, v2)),
56835691
(VTI.Vec (Inst v2, v1, is))>;
5692+
def : Pat<(VTI.Vec (int_arm_mve_fma v1, vs, v2)),
5693+
(VTI.Vec (Inst v2, v1, is))>;
5694+
def : Pat<(VTI.Vec (int_arm_mve_fma vs, v1, v2)),
5695+
(VTI.Vec (Inst v2, v1, is))>;
56845696
def : Pat<(VTI.Vec (vselect (VTI.Pred VCCR:$pred),
56855697
(VTI.Vec (fma vs, v2, v1)),
56865698
v1)),

llvm/test/CodeGen/Thumb2/mve-intrinsics/strict-intrinsics.ll

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
22
; RUN: llc -mtriple=thumbv8.1m.main -mattr=+mve.fp -o - %s | FileCheck %s
33

4-
define arm_aapcs_vfpcc <8 x half> @test_vaddq_f16(<8 x half> %a, <8 x half> %b) {
4+
define arm_aapcs_vfpcc <8 x half> @test_vaddq_f16(<8 x half> %a, <8 x half> %b) #0 {
55
; CHECK-LABEL: test_vaddq_f16:
66
; CHECK: @ %bb.0: @ %entry
77
; CHECK-NEXT: vadd.f16 q0, q0, q1
@@ -11,7 +11,7 @@ entry:
1111
ret <8 x half> %0
1212
}
1313

14-
define arm_aapcs_vfpcc <4 x float> @test_vaddq_f32(<4 x float> %a, <4 x float> %b) {
14+
define arm_aapcs_vfpcc <4 x float> @test_vaddq_f32(<4 x float> %a, <4 x float> %b) #0 {
1515
; CHECK-LABEL: test_vaddq_f32:
1616
; CHECK: @ %bb.0: @ %entry
1717
; CHECK-NEXT: vadd.f32 q0, q0, q1
@@ -21,7 +21,7 @@ entry:
2121
ret <4 x float> %0
2222
}
2323

24-
define arm_aapcs_vfpcc <8 x half> @test_vsubq_f16(<8 x half> %a, <8 x half> %b) {
24+
define arm_aapcs_vfpcc <8 x half> @test_vsubq_f16(<8 x half> %a, <8 x half> %b) #0 {
2525
; CHECK-LABEL: test_vsubq_f16:
2626
; CHECK: @ %bb.0: @ %entry
2727
; CHECK-NEXT: vsub.f16 q0, q0, q1
@@ -31,7 +31,7 @@ entry:
3131
ret <8 x half> %0
3232
}
3333

34-
define arm_aapcs_vfpcc <4 x float> @test_vsubq_f32(<4 x float> %a, <4 x float> %b) {
34+
define arm_aapcs_vfpcc <4 x float> @test_vsubq_f32(<4 x float> %a, <4 x float> %b) #0 {
3535
; CHECK-LABEL: test_vsubq_f32:
3636
; CHECK: @ %bb.0: @ %entry
3737
; CHECK-NEXT: vsub.f32 q0, q0, q1
@@ -41,7 +41,7 @@ entry:
4141
ret <4 x float> %0
4242
}
4343

44-
define arm_aapcs_vfpcc <8 x half> @test_vmulq_f16(<8 x half> %a, <8 x half> %b) {
44+
define arm_aapcs_vfpcc <8 x half> @test_vmulq_f16(<8 x half> %a, <8 x half> %b) #0 {
4545
; CHECK-LABEL: test_vmulq_f16:
4646
; CHECK: @ %bb.0: @ %entry
4747
; CHECK-NEXT: vmul.f16 q0, q0, q1
@@ -51,7 +51,7 @@ entry:
5151
ret <8 x half> %0
5252
}
5353

54-
define arm_aapcs_vfpcc <4 x float> @test_vmulq_f32(<4 x float> %a, <4 x float> %b) {
54+
define arm_aapcs_vfpcc <4 x float> @test_vmulq_f32(<4 x float> %a, <4 x float> %b) #0 {
5555
; CHECK-LABEL: test_vmulq_f32:
5656
; CHECK: @ %bb.0: @ %entry
5757
; CHECK-NEXT: vmul.f32 q0, q0, q1
@@ -64,7 +64,7 @@ entry:
6464

6565

6666

67-
define arm_aapcs_vfpcc <8 x half> @test_vaddq_f16_splat(<8 x half> %a, half %b) {
67+
define arm_aapcs_vfpcc <8 x half> @test_vaddq_f16_splat(<8 x half> %a, half %b) #0 {
6868
; CHECK-LABEL: test_vaddq_f16_splat:
6969
; CHECK: @ %bb.0: @ %entry
7070
; CHECK-NEXT: vmov.f16 r0, s4
@@ -77,7 +77,7 @@ entry:
7777
ret <8 x half> %0
7878
}
7979

80-
define arm_aapcs_vfpcc <4 x float> @test_vaddq_f32_splat(<4 x float> %a, float %b) {
80+
define arm_aapcs_vfpcc <4 x float> @test_vaddq_f32_splat(<4 x float> %a, float %b) #0 {
8181
; CHECK-LABEL: test_vaddq_f32_splat:
8282
; CHECK: @ %bb.0: @ %entry
8383
; CHECK-NEXT: vmov r0, s4
@@ -90,7 +90,7 @@ entry:
9090
ret <4 x float> %0
9191
}
9292

93-
define arm_aapcs_vfpcc <8 x half> @test_vsubq_f16_splat(<8 x half> %a, half %b) {
93+
define arm_aapcs_vfpcc <8 x half> @test_vsubq_f16_splat(<8 x half> %a, half %b) #0 {
9494
; CHECK-LABEL: test_vsubq_f16_splat:
9595
; CHECK: @ %bb.0: @ %entry
9696
; CHECK-NEXT: vmov.f16 r0, s4
@@ -103,7 +103,7 @@ entry:
103103
ret <8 x half> %0
104104
}
105105

106-
define arm_aapcs_vfpcc <4 x float> @test_vsubq_f32_splat(<4 x float> %a, float %b) {
106+
define arm_aapcs_vfpcc <4 x float> @test_vsubq_f32_splat(<4 x float> %a, float %b) #0 {
107107
; CHECK-LABEL: test_vsubq_f32_splat:
108108
; CHECK: @ %bb.0: @ %entry
109109
; CHECK-NEXT: vmov r0, s4
@@ -116,7 +116,7 @@ entry:
116116
ret <4 x float> %0
117117
}
118118

119-
define arm_aapcs_vfpcc <8 x half> @test_vmulq_f16_splat(<8 x half> %a, half %b) {
119+
define arm_aapcs_vfpcc <8 x half> @test_vmulq_f16_splat(<8 x half> %a, half %b) #0 {
120120
; CHECK-LABEL: test_vmulq_f16_splat:
121121
; CHECK: @ %bb.0: @ %entry
122122
; CHECK-NEXT: vmov.f16 r0, s4
@@ -129,7 +129,7 @@ entry:
129129
ret <8 x half> %0
130130
}
131131

132-
define arm_aapcs_vfpcc <4 x float> @test_vmulq_f32_splat(<4 x float> %a, float %b) {
132+
define arm_aapcs_vfpcc <4 x float> @test_vmulq_f32_splat(<4 x float> %a, float %b) #0 {
133133
; CHECK-LABEL: test_vmulq_f32_splat:
134134
; CHECK: @ %bb.0: @ %entry
135135
; CHECK-NEXT: vmov r0, s4
@@ -141,3 +141,103 @@ entry:
141141
%0 = tail call <4 x float> @llvm.arm.mve.vmul.v4f32(<4 x float> %a, <4 x float> %s)
142142
ret <4 x float> %0
143143
}
144+
145+
define arm_aapcs_vfpcc <4 x float> @fma_v4f32(<4 x float> %dst, <4 x float> %s1, <4 x float> %s2) #0 {
146+
; CHECK-LABEL: fma_v4f32:
147+
; CHECK: @ %bb.0: @ %entry
148+
; CHECK-NEXT: vfma.f32 q0, q1, q2
149+
; CHECK-NEXT: bx lr
150+
entry:
151+
%0 = tail call <4 x float> @llvm.arm.mve.fma.v4f32(<4 x float> %s1, <4 x float> %s2, <4 x float> %dst)
152+
ret <4 x float> %0
153+
}
154+
155+
define arm_aapcs_vfpcc <8 x half> @fma_v8f16(<8 x half> %dst, <8 x half> %s1, <8 x half> %s2) #0 {
156+
; CHECK-LABEL: fma_v8f16:
157+
; CHECK: @ %bb.0: @ %entry
158+
; CHECK-NEXT: vfma.f16 q0, q1, q2
159+
; CHECK-NEXT: bx lr
160+
entry:
161+
%0 = tail call <8 x half> @llvm.arm.mve.fma.v8f16(<8 x half> %s1, <8 x half> %s2, <8 x half> %dst)
162+
ret <8 x half> %0
163+
}
164+
165+
define arm_aapcs_vfpcc <4 x float> @fma_n_v8f16(<4 x float> %s1, <4 x float> %s2, float %s3) #0 {
166+
; CHECK-LABEL: fma_n_v8f16:
167+
; CHECK: @ %bb.0: @ %entry
168+
; CHECK-NEXT: vmov r0, s8
169+
; CHECK-NEXT: vfma.f32 q0, q1, r0
170+
; CHECK-NEXT: bx lr
171+
entry:
172+
%i = insertelement <4 x float> poison, float %s3, i32 0
173+
%sp = shufflevector <4 x float> %i, <4 x float> poison, <4 x i32> zeroinitializer
174+
%0 = tail call <4 x float> @llvm.arm.mve.fma.v4f32(<4 x float> %s2, <4 x float> %sp, <4 x float> %s1)
175+
ret <4 x float> %0
176+
}
177+
178+
define arm_aapcs_vfpcc <8 x half> @fma_n_v4f32(<8 x half> %s1, <8 x half> %s2, half %s3) #0 {
179+
; CHECK-LABEL: fma_n_v4f32:
180+
; CHECK: @ %bb.0: @ %entry
181+
; CHECK-NEXT: vmov.f16 r0, s8
182+
; CHECK-NEXT: vfma.f16 q0, q1, r0
183+
; CHECK-NEXT: bx lr
184+
entry:
185+
%i = insertelement <8 x half> poison, half %s3, i32 0
186+
%sp = shufflevector <8 x half> %i, <8 x half> poison, <8 x i32> zeroinitializer
187+
%0 = tail call <8 x half> @llvm.arm.mve.fma.v8f16(<8 x half> %s2, <8 x half> %sp, <8 x half> %s1)
188+
ret <8 x half> %0
189+
}
190+
191+
define arm_aapcs_vfpcc <4 x float> @fms_v4f32(<4 x float> %dst, <4 x float> %s1, <4 x float> %s2) #0 {
192+
; CHECK-LABEL: fms_v4f32:
193+
; CHECK: @ %bb.0: @ %entry
194+
; CHECK-NEXT: vfms.f32 q0, q1, q2
195+
; CHECK-NEXT: bx lr
196+
entry:
197+
%c = fneg <4 x float> %s1
198+
%0 = tail call <4 x float> @llvm.arm.mve.fma.v4f32(<4 x float> %c, <4 x float> %s2, <4 x float> %dst)
199+
ret <4 x float> %0
200+
}
201+
202+
define arm_aapcs_vfpcc <8 x half> @fms_v8f16(<8 x half> %dst, <8 x half> %s1, <8 x half> %s2) #0 {
203+
; CHECK-LABEL: fms_v8f16:
204+
; CHECK: @ %bb.0: @ %entry
205+
; CHECK-NEXT: vfms.f16 q0, q1, q2
206+
; CHECK-NEXT: bx lr
207+
entry:
208+
%c = fneg <8 x half> %s1
209+
%0 = tail call <8 x half> @llvm.arm.mve.fma.v8f16(<8 x half> %c, <8 x half> %s2, <8 x half> %dst)
210+
ret <8 x half> %0
211+
}
212+
213+
define arm_aapcs_vfpcc <4 x float> @fms_n_v8f16(<4 x float> %s1, <4 x float> %s2, float %s3) #0 {
214+
; CHECK-LABEL: fms_n_v8f16:
215+
; CHECK: @ %bb.0: @ %entry
216+
; CHECK-NEXT: vmov r0, s8
217+
; CHECK-NEXT: vdup.32 q2, r0
218+
; CHECK-NEXT: vfms.f32 q0, q1, q2
219+
; CHECK-NEXT: bx lr
220+
entry:
221+
%c = fneg <4 x float> %s2
222+
%i = insertelement <4 x float> poison, float %s3, i32 0
223+
%sp = shufflevector <4 x float> %i, <4 x float> poison, <4 x i32> zeroinitializer
224+
%0 = tail call <4 x float> @llvm.arm.mve.fma.v4f32(<4 x float> %c, <4 x float> %sp, <4 x float> %s1)
225+
ret <4 x float> %0
226+
}
227+
228+
define arm_aapcs_vfpcc <8 x half> @fms_n_v4f32(<8 x half> %s1, <8 x half> %s2, half %s3) #0 {
229+
; CHECK-LABEL: fms_n_v4f32:
230+
; CHECK: @ %bb.0: @ %entry
231+
; CHECK-NEXT: vmov.f16 r0, s8
232+
; CHECK-NEXT: vdup.16 q2, r0
233+
; CHECK-NEXT: vfms.f16 q0, q1, q2
234+
; CHECK-NEXT: bx lr
235+
entry:
236+
%c = fneg <8 x half> %s2
237+
%i = insertelement <8 x half> poison, half %s3, i32 0
238+
%sp = shufflevector <8 x half> %i, <8 x half> poison, <8 x i32> zeroinitializer
239+
%0 = tail call <8 x half> @llvm.arm.mve.fma.v8f16(<8 x half> %c, <8 x half> %sp, <8 x half> %s1)
240+
ret <8 x half> %0
241+
}
242+
243+
attributes #0 = { strictfp }

0 commit comments

Comments
 (0)