Skip to content

Commit 28af544

Browse files
Authored by smk2007 (Sheil Kumar) and co-authors
[DirectML] Broadcast NC-dims for Tensors A&B in DynamicQuantizeMatMul (microsoft#21298)
### Description

[DirectML] Broadcast NC-dims for Tensors A & B in DynamicQuantizeMatMul (microsoft#21298)

The DynamicQuantizeMatMul operator allows input tensors in NCHW format, and DirectML requires that input tensors share the same batch and channel dimensions. Tensors A and B should therefore be broadcast (where possible) to the corresponding output NC dimensions.

### Motivation and Context

Certain models that use DynamicQuantizeMatMul hit a crash when the NC dims are intended to be broadcast.

---------

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
1 parent 20cd339 commit 28af544

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
4040
kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0)
4141
);
4242
}
43+
MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::A).tensorDataType;
4344
MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType;
4445

46+
gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
4547
std::vector<uint32_t> ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A);
46-
std::vector<uint32_t> ExpectedAScaleTensorShape = {1, 1, 1, 1};
47-
std::vector<uint32_t> ExpectedAZeroPointTensorShape = {1, 1, 1, 1};
48+
std::vector<uint32_t> BTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::B);
49+
std::vector<uint32_t> ExpectedAScaleTensorShape(outputSizes.size(), 1);
50+
std::vector<uint32_t> ExpectedAZeroPointTensorShape(outputSizes.size(), 1);
51+
ML_CHECK_VALID_ARGUMENT(outputSizes.size() >= 4);
52+
ML_CHECK_VALID_ARGUMENT(ATensorShape.size() >= 2);
53+
ML_CHECK_VALID_ARGUMENT(BTensorShape.size() >= 2);
54+
ML_CHECK_VALID_ARGUMENT(ATensorShape.size() + 2 >= outputSizes.size());
55+
ML_CHECK_VALID_ARGUMENT(BTensorShape.size() + 2 >= outputSizes.size());
56+
std::vector<uint32_t> AShapeBroadcasted(outputSizes.begin(), outputSizes.end());
57+
std::copy(ATensorShape.end() - (outputSizes.size() - 2),
58+
ATensorShape.end(),
59+
AShapeBroadcasted.begin() + 2);
60+
std::vector<uint32_t> BShapeBroadcasted(outputSizes.begin(), outputSizes.end());
61+
std::copy(BTensorShape.end() - (outputSizes.size() - 2),
62+
BTensorShape.end(),
63+
BShapeBroadcasted.begin() + 2);
4864

4965
// output edges between DynQL and MMItoFloat node
5066
TensorDesc intermediateQuantizedATensorDesc = TensorDesc(
5167
BDatatype,
52-
gsl::make_span(ATensorShape),
68+
gsl::make_span(AShapeBroadcasted),
5369
gsl::make_span(ATensorShape),
5470
TensorAxis::DoNotCoerce,
5571
TensorAxis::W,
@@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
8096
0 // guaranteedBaseOffsetAlignment
8197
);
8298

99+
TensorDesc broadcastedATensorDesc = TensorDesc(
100+
ADatatype,
101+
AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
102+
ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
103+
TensorAxis::DoNotCoerce,
104+
TensorAxis::W,
105+
TensorAxis::RightAligned,
106+
NchwDimensionCount, // minDimensionCount
107+
0 // guaranteedBaseOffsetAlignment
108+
);
109+
110+
TensorDesc broadcastedBTensorDesc = TensorDesc(
111+
BDatatype,
112+
BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
113+
BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
114+
TensorAxis::DoNotCoerce,
115+
TensorAxis::W,
116+
TensorAxis::RightAligned,
117+
NchwDimensionCount, // minDimensionCount
118+
0 // guaranteedBaseOffsetAlignment
119+
);
120+
121+
DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc();
122+
DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc();
83123
DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc();
84124
DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc();
85125
DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc();
@@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
88128
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
89129

90130
DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
91-
dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A];
131+
dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc;
92132
dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
93133
dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
94134
dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;
@@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
99139
matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor;
100140
matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor;
101141
matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor;
102-
matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B];
142+
matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc;
103143
matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
104144
matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr;
105145
matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;

0 commit comments

Comments
 (0)