Skip to content

Commit 65cf01c

Browse files
vsemenov368jsji
authored andcommitted
Implement SPV_INTEL_sigmoid extension (#3420)
Spec link: #20504 Original commit: KhronosGroup/SPIRV-LLVM-Translator@a3a5de742179b34
1 parent 4019d91 commit 65cf01c

File tree

7 files changed

+166
-0
lines changed

7 files changed

+166
-0
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,4 @@ EXT(SPV_INTEL_function_variants)
8383
EXT(SPV_INTEL_shader_atomic_bfloat16)
8484
EXT(SPV_EXT_float8)
8585
EXT(SPV_INTEL_predicated_io)
86+
EXT(SPV_INTEL_sigmoid)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4496,5 +4496,64 @@ class SPIRVPredicatedIOINTELInst : public SPIRVInstTemplateBase {
44964496
_SPIRV_OP(PredicatedLoad, true, 6, true)
44974497
_SPIRV_OP(PredicatedStore, false, 4, true)
44984498
#undef _SPIRV_OP
4499+
4500+
template <Op OC> class SPIRVFSigmoidINTELInstBase : public SPIRVUnaryInst<OC> {
4501+
protected:
4502+
SPIRVCapVec getRequiredCapability() const override {
4503+
return getVec(internal::CapabilitySigmoidINTEL);
4504+
}
4505+
4506+
std::optional<ExtensionID> getRequiredExtension() const override {
4507+
return ExtensionID::SPV_INTEL_sigmoid;
4508+
}
4509+
4510+
void validate() const override {
4511+
SPIRVUnaryInst<OC>::validate();
4512+
4513+
SPIRVType *ResCompTy = this->getType();
4514+
SPIRVWord ResCompCount = 1;
4515+
if (ResCompTy->isTypeVector()) {
4516+
ResCompCount = ResCompTy->getVectorComponentCount();
4517+
ResCompTy = ResCompTy->getVectorComponentType();
4518+
}
4519+
4520+
// validate is a const method, whilst getOperand is non-const method
4521+
// because it may call a method of class Module that may modify LiteralMap
4522+
// of Module field. That modification is not impacting validate method for
4523+
// these instructions, so const_cast is safe here.
4524+
using SPVFSigmoidTy = SPIRVFSigmoidINTELInstBase<OC>;
4525+
const SPIRVValue *Input = const_cast<SPVFSigmoidTy *>(this)->getOperand(0);
4526+
4527+
SPIRVType *InCompTy = Input->getType();
4528+
SPIRVWord InCompCount = 1;
4529+
if (InCompTy->isTypeVector()) {
4530+
InCompCount = InCompTy->getVectorComponentCount();
4531+
InCompTy = InCompTy->getVectorComponentType();
4532+
}
4533+
4534+
auto InstName = OpCodeNameMap::map(OC);
4535+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
4536+
4537+
SPVErrLog.checkError(
4538+
ResCompTy->isTypeFloat(16) || ResCompTy->isTypeFloat(32) ||
4539+
ResCompTy->isTypeFloat(16, FPEncodingBFloat16KHR),
4540+
SPIRVEC_InvalidInstruction,
4541+
InstName + "\nResult value must be a scalar or vector of floating-point"
4542+
" 16-bit or 32-bit type\n");
4543+
SPVErrLog.checkError(
4544+
ResCompTy == InCompTy, SPIRVEC_InvalidInstruction,
4545+
InstName +
4546+
"\nInput type must have the same component type as result type\n");
4547+
SPVErrLog.checkError(
4548+
ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
4549+
InstName + "\nInput type must have the same number of components as "
4550+
"result type\n");
4551+
}
4552+
};
4553+
4554+
#define _SPIRV_OP(x, ...) \
4555+
typedef SPIRVFSigmoidINTELInstBase<internal::Op##x> SPIRV##x;
4556+
_SPIRV_OP(FSigmoidINTEL)
4557+
#undef _SPIRV_OP
44994558
} // namespace SPIRV
45004559
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
709709
add(CapabilityFloat8EXT, "Float8EXT");
710710
add(CapabilityFloat8CooperativeMatrixEXT, "Float8CooperativeMatrixEXT");
711711
add(internal::CapabilityPredicatedIOINTEL, "PredicatedIOINTEL");
712+
add(internal::CapabilitySigmoidINTEL, "SigmoidINTEL");
712713
}
713714
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
714715

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ _SPIRV_OP_INTERNAL(PredicatedLoadINTEL,
4545
internal::OpPredicatedLoadINTEL)
4646
_SPIRV_OP_INTERNAL(PredicatedStoreINTEL,
4747
internal::OpPredicatedStoreINTEL)
48+
_SPIRV_OP_INTERNAL(FSigmoidINTEL, internal::FSigmoidINTEL)

llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ enum InternalOp {
8888
IOpConvertHandleToImageINTEL = 6529,
8989
IOpConvertHandleToSamplerINTEL = 6530,
9090
IOpConvertHandleToSampledImageINTEL = 6531,
91+
IOpFSigmoidINTEL = 6168,
9192
IOpPrev = OpMax - 2,
9293
IOpForward
9394
};
@@ -108,6 +109,7 @@ enum InternalCapability {
108109
ICapabilityHWThreadQueryINTEL = 6134,
109110
ICapGlobalVariableDecorationsINTEL = 6146,
110111
ICapabilityTaskSequenceINTEL = 6162,
112+
ICapabilitySigmoidINTEL = 6167,
111113
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
112114
ICapabilityBFloat16ArithmeticINTEL = 6226,
113115
ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6238,
@@ -219,6 +221,8 @@ _SPIRV_OP(Capability, PredicatedIOINTEL)
219221
_SPIRV_OP(Op, PredicatedLoadINTEL)
220222
_SPIRV_OP(Op, PredicatedStoreINTEL)
221223

224+
_SPIRV_OP(Capability, SigmoidINTEL)
225+
_SPIRV_OP(Op, FSigmoidINTEL)
222226
#undef _SPIRV_OP
223227

224228
constexpr SourceLanguage SourceLanguagePython =
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_sigmoid
3+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR
6+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
7+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_INTEL_sigmoid
12+
13+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
14+
target triple = "spir64-unknown-unknown"
15+
16+
; CHECK-SPIRV: Capability SigmoidINTEL
17+
; CHECK-SPIRV: Extension "SPV_INTEL_sigmoid"
18+
; CHECK-SPIRV: TypeFloat [[#FP16Ty:]] 16
19+
; CHECK-SPIRV: TypeVector [[#FP16v8Ty:]] [[#FP16Ty]] 8
20+
; CHECK-SPIRV: Constant [[#FP16Ty]] [[#CONST:]] 15360
21+
22+
; CHECK-SPIRV: FunctionParameter [[#FP16Ty]] [[FP16ValId:.*]]
23+
; CHECK-SPIRV: FunctionParameter [[#FP16v8Ty]] [[FP16v8ValId:.*]]
24+
25+
; CHECK-SPIRV: FSigmoidINTEL [[#FP16Ty]] [[#]] [[FP16ValId]]
26+
; CHECK-SPIRV: FSigmoidINTEL [[#FP16v8Ty]] [[#]] [[FP16v8ValId]]
27+
; CHECK-SPIRV: FSigmoidINTEL [[#FP16Ty]] [[#]] [[#CONST]]
28+
29+
; CHECK-LLVM: call spir_func half @_Z21__spirv_FSigmoidINTELDh(half
30+
; CHECK-LLVM: call spir_func <8 x half> @_Z21__spirv_FSigmoidINTELDv8_Dh(<8 x half>
31+
; CHECK-LLVM: call spir_func half @_Z21__spirv_FSigmoidINTELDh(half 0xH3C00)
32+
33+
define spir_func void @_Z2opffv8(half %a, <8 x half> %in) {
34+
%1 = tail call spir_func half @_Z21__spirv_FSigmoidINTELDh(half %a)
35+
%2 = tail call spir_func <8 x half> @_Z21__spirv_FSigmoidINTELDv8_Dh(<8 x half> %in)
36+
%3 = tail call spir_func half @_Z21__spirv_FSigmoidINTELDh(half 1.000000e+00)
37+
ret void
38+
}
39+
40+
declare spir_func half @_Z21__spirv_FSigmoidINTELDh(half)
41+
42+
declare spir_func <8 x half> @_Z21__spirv_FSigmoidINTELDv8_Dh(<8 x half>)
43+
44+
!opencl.spir.version = !{!0}
45+
!spirv.Source = !{!1}
46+
!llvm.ident = !{!2}
47+
48+
!0 = !{i32 1, i32 2}
49+
!1 = !{i32 4, i32 100000}
50+
!2 = !{!"clang version 16.0.0"}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_sigmoid
3+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR
6+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
7+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_INTEL_sigmoid
12+
13+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
14+
target triple = "spir64-unknown-unknown"
15+
16+
; CHECK-SPIRV: Capability SigmoidINTEL
17+
; CHECK-SPIRV: Extension "SPV_INTEL_sigmoid"
18+
; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32
19+
; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8
20+
; CHECK-SPIRV: Constant [[#FP32Ty]] [[#CONST:]] 1065353216
21+
22+
; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]]
23+
; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]]
24+
25+
; CHECK-SPIRV: FSigmoidINTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
26+
; CHECK-SPIRV: FSigmoidINTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
27+
; CHECK-SPIRV: FSigmoidINTEL [[#FP32Ty]] [[#]] [[#CONST]]
28+
29+
; CHECK-LLVM: call spir_func float @_Z21__spirv_FSigmoidINTELf(float
30+
; CHECK-LLVM: call spir_func <8 x float> @_Z21__spirv_FSigmoidINTELDv8_f(<8 x float>
31+
; CHECK-LLVM: call spir_func float @_Z21__spirv_FSigmoidINTELf(float 1.000000e+00)
32+
33+
define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
34+
%1 = tail call spir_func float @_Z21__spirv_FSigmoidINTELf(float %a)
35+
%2 = tail call spir_func <8 x float> @_Z21__spirv_FSigmoidINTELDv8_f(<8 x float> %in)
36+
%3 = tail call spir_func float @_Z21__spirv_FSigmoidINTELf(float 1.000000e+00)
37+
ret void
38+
}
39+
40+
declare spir_func float @_Z21__spirv_FSigmoidINTELf(float)
41+
42+
declare spir_func <8 x float> @_Z21__spirv_FSigmoidINTELDv8_f(<8 x float>)
43+
44+
!opencl.spir.version = !{!0}
45+
!spirv.Source = !{!1}
46+
!llvm.ident = !{!2}
47+
48+
!0 = !{i32 1, i32 2}
49+
!1 = !{i32 4, i32 100000}
50+
!2 = !{!"clang version 16.0.0"}

0 commit comments

Comments
 (0)