@@ -40,16 +40,32 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
4040 kernelCreationContext.GetTensorShapeDescription ().GetOutputTensorShape (0 )
4141 );
4242 }
43+ MLOperatorTensorDataType ADatatype = kernelCreationContext.GetInputEdgeDescription (OnnxInputIndex::A).tensorDataType ;
4344 MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription (OnnxInputIndex::B).tensorDataType ;
4445
46+ gsl::span<const uint32_t > outputSizes = m_outputTensorDescs[0 ].GetSizes ();
4547 std::vector<uint32_t > ATensorShape = kernelCreationContext.GetTensorShapeDescription ().GetInputTensorShape (OnnxInputIndex::A);
46- std::vector<uint32_t > ExpectedAScaleTensorShape = {1 , 1 , 1 , 1 };
47- std::vector<uint32_t > ExpectedAZeroPointTensorShape = {1 , 1 , 1 , 1 };
48+ std::vector<uint32_t > BTensorShape = kernelCreationContext.GetTensorShapeDescription ().GetInputTensorShape (OnnxInputIndex::B);
49+ std::vector<uint32_t > ExpectedAScaleTensorShape (outputSizes.size (), 1 );
50+ std::vector<uint32_t > ExpectedAZeroPointTensorShape (outputSizes.size (), 1 );
51+ ML_CHECK_VALID_ARGUMENT (outputSizes.size () >= 4 );
52+ ML_CHECK_VALID_ARGUMENT (ATensorShape.size () >= 2 );
53+ ML_CHECK_VALID_ARGUMENT (BTensorShape.size () >= 2 );
54+ ML_CHECK_VALID_ARGUMENT (ATensorShape.size () + 2 >= outputSizes.size ());
55+ ML_CHECK_VALID_ARGUMENT (BTensorShape.size () + 2 >= outputSizes.size ());
56+ std::vector<uint32_t > AShapeBroadcasted (outputSizes.begin (), outputSizes.end ());
57+ std::copy (ATensorShape.end () - (outputSizes.size () - 2 ),
58+ ATensorShape.end (),
59+ AShapeBroadcasted.begin () + 2 );
60+ std::vector<uint32_t > BShapeBroadcasted (outputSizes.begin (), outputSizes.end ());
61+ std::copy (BTensorShape.end () - (outputSizes.size () - 2 ),
62+ BTensorShape.end (),
63+ BShapeBroadcasted.begin () + 2 );
4864
4965 // output edges between DynQL and MMItoFloat node
5066 TensorDesc intermediateQuantizedATensorDesc = TensorDesc (
5167 BDatatype,
52- gsl::make_span (ATensorShape ),
68+ gsl::make_span (AShapeBroadcasted ),
5369 gsl::make_span (ATensorShape),
5470 TensorAxis::DoNotCoerce,
5571 TensorAxis::W,
@@ -80,6 +96,30 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
8096 0 // guaranteedBaseOffsetAlignment
8197 );
8298
99+ TensorDesc broadcastedATensorDesc = TensorDesc (
100+ ADatatype,
101+ AShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
102+ ATensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
103+ TensorAxis::DoNotCoerce,
104+ TensorAxis::W,
105+ TensorAxis::RightAligned,
106+ NchwDimensionCount, // minDimensionCount
107+ 0 // guaranteedBaseOffsetAlignment
108+ );
109+
110+ TensorDesc broadcastedBTensorDesc = TensorDesc (
111+ BDatatype,
112+ BShapeBroadcasted, // Desired dimensions of tensor (after any broadcasting).
113+ BTensorShape, // Original dimensions (before any broadcasting). Usually same as 'dimensions'.
114+ TensorAxis::DoNotCoerce,
115+ TensorAxis::W,
116+ TensorAxis::RightAligned,
117+ NchwDimensionCount, // minDimensionCount
118+ 0 // guaranteedBaseOffsetAlignment
119+ );
120+
121+ DML_TENSOR_DESC namedBroadcastedATensorDesc = broadcastedATensorDesc.GetDmlDesc ();
122+ DML_TENSOR_DESC namedBroadcastedBTensorDesc = broadcastedBTensorDesc.GetDmlDesc ();
83123 DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc ();
84124 DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc ();
85125 DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc ();
@@ -88,7 +128,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
88128 std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs ();
89129
90130 DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
91- dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A] ;
131+ dynamicQuantizeLinearOperatorDesc.InputTensor = &namedBroadcastedATensorDesc ;
92132 dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
93133 dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
94134 dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;
@@ -99,7 +139,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
99139 matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor ;
100140 matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor ;
101141 matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor ;
102- matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B] ;
142+ matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &namedBroadcastedBTensorDesc ;
103143 matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
104144 matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr ;
105145 matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr ;
0 commit comments