2929#include < optional>
3030#include < random>
3131
32+ #include " mlir/Dialect/Tosa/Utils/QuantUtils.h"
33+
3234using namespace mlir ;
3335using namespace mlir ::torch;
3436using namespace mlir ::torch::Torch;
@@ -2295,7 +2297,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
22952297 auto weightTy = cast<RankedTensorType>(weight.getType ());
22962298 auto outputTy =
22972299 cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
2298-
22992300 if (!inputTy || !weightTy || !outputTy)
23002301 return rewriter.notifyMatchFailure (
23012302 op, " Input, weight and output to Convolution must be ranked tensors" );
@@ -2304,6 +2305,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23042305 auto weightElemTy = weightTy.getElementType ();
23052306 auto inputShape = makeShapeTorchCompatible (inputTy.getShape ());
23062307 auto weightShape = makeShapeTorchCompatible (weightTy.getShape ());
2308+ auto outputElemTy = outputTy.getElementType ();
23072309
23082310 if (inputTy.getRank () != 4 )
23092311 return rewriter.notifyMatchFailure (
@@ -2316,28 +2318,21 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23162318 // Bias is optional. TOSA mandates a zero tensor here, so construct one if
23172319 // required.
23182320 auto bias = adaptor.getBias ();
2319- if (isa<Torch::NoneType>(adaptor.getBias ().getType ())) {
2320- // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
2321- // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
2322- // define a 48-bit int.
2323- if (isa<quant::QuantizedType>(inputElemTy)) {
2324- SmallVector<int32_t > zeroVec (weightShape[0 ], 0 );
2325- bias = tosa::getConstTensor<int32_t >(
2326- rewriter, op, zeroVec, {static_cast <int32_t >(weightShape[0 ])})
2327- .value ();
2328- } else {
2329- SmallVector<float > zeroVec (weightShape[0 ], 0 );
2330- bias = tosa::getConstTensor<float >(rewriter, op, zeroVec,
2331- {static_cast <int32_t >(weightShape[0 ])})
2332- .value ();
2333- }
2321+
2322+ if (isa<Torch::NoneType>(bias.getType ())) {
2323+ auto bias_result = tosa::getConvBiasForNoneType (op, rewriter, inputElemTy,
2324+ outputElemTy, weightShape);
2325+ if (failed (bias_result))
2326+ return rewriter.notifyMatchFailure (
2327+ op, " Failed to create bias tensor for none type." );
2328+ bias = bias_result.value ();
23342329 } else {
2335- if (!cast <RankedTensorType>(bias.getType ()))
2330+ if (!isa <RankedTensorType>(bias.getType ()))
23362331 return rewriter.notifyMatchFailure (
23372332 op, " Bias provided but not a ranked tensor" );
23382333 }
2339- auto biasElemTy =
2340- isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter. getI32Type ();
2334+
2335+ Type biasElemTy = cast<RankedTensorType>(bias. getType ()). getElementType ();
23412336
23422337 int64_t groups;
23432338 if (!matchPattern (op.getGroups (), m_TorchConstantInt (&groups))) {
@@ -2528,14 +2523,29 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25282523 auto convOpTy =
25292524 RankedTensorType::get (makeShapeLLVMCompatible (outputShape), biasElemTy);
25302525
2526+ // create zero-point tensors for input and weight
2527+ auto zps = tosa::createZPsAsConst (rewriter, input, weight);
2528+ // for i8 input/weight, zero-points are returned as un-initialized
2529+ Value inputZp =
2530+ zps.first
2531+ ? zps.first
2532+ : tosa::createZeroPointTensor (rewriter, op->getLoc (), inputElemTy, 0 )
2533+ .value ();
2534+
2535+ Value weightZp =
2536+ zps.second
2537+ ? zps.second
2538+ : tosa::createZeroPointTensor (rewriter, op->getLoc (), weightElemTy, 0 )
2539+ .value ();
2540+
25312541 Value convOpResult;
25322542 if (groups == 1 ) {
25332543 // full convolution
25342544 convOpResult =
25352545 rewriter
25362546 .create <tosa::Conv2DOp>(
25372547 op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2538- transposedInput, transformedWeight, bias,
2548+ transposedInput, transformedWeight, bias, inputZp, weightZp,
25392549 rewriter.getDenseI64ArrayAttr (padding),
25402550 rewriter.getDenseI64ArrayAttr (stride),
25412551 rewriter.getDenseI64ArrayAttr (dilation), accType)
@@ -2546,7 +2556,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25462556 rewriter
25472557 .create <tosa::DepthwiseConv2DOp>(
25482558 op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2549- transposedInput, transformedWeight, bias,
2559+ transposedInput, transformedWeight, bias, inputZp, weightZp,
25502560 rewriter.getDenseI64ArrayAttr (padding),
25512561 rewriter.getDenseI64ArrayAttr (stride),
25522562 rewriter.getDenseI64ArrayAttr (dilation), accType)
@@ -2574,8 +2584,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25742584 rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
25752585 }
25762586
2577- rewriter.replaceOpWithNewOp <tensor::CastOp>(
2578- op, getTypeConverter ()->convertType (op.getType ()), rescaledResult);
2587+ // cast to outputTy is required if convOpTy is not same as outputTy
2588+ // the difference is not in the shape information, rather the element-type
2589+ // itself
2590+ rewriter.replaceOp (
2591+ op,
2592+ {tosa::tosaCastTensorToType (rewriter, rescaledResult, outputTy).value ()});
25792593
25802594 return success ();
25812595}
0 commit comments