@@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
10701070 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
10711071 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
10721072 llvm::SmallVector<int64_t > outputShape (5 , ShapedType::kDynamicSize );
1073- Conv2DOp ::Adaptor adaptor (operands.getValues (), attributes);
1073+ Conv3DOp ::Adaptor adaptor (operands.getValues (), attributes);
10741074
10751075 int32_t inputWidth = ShapedType::kDynamicSize ;
10761076 int32_t inputHeight = ShapedType::kDynamicSize ;
@@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
10841084 ShapeAdaptor inputShape = operands.getShape (adaptor.getInput ());
10851085 if (inputShape.hasRank ()) {
10861086 outputShape[0 ] = inputShape.getDimSize (0 );
1087- inputHeight = inputShape.getDimSize (1 );
1088- inputWidth = inputShape.getDimSize (2 );
1089- inputDepth = inputShape.getDimSize (3 );
1087+ inputDepth = inputShape.getDimSize (1 );
1088+ inputHeight = inputShape.getDimSize (2 );
1089+ inputWidth = inputShape.getDimSize (3 );
10901090 }
10911091
10921092 // Weight shapes describes the filter width/height and the output channels.
10931093 ShapeAdaptor weightShape = operands.getShape (adaptor.getWeight ());
10941094 if (weightShape.hasRank ()) {
10951095 outputShape[4 ] = weightShape.getDimSize (0 );
1096- weightHeight = weightShape.getDimSize (1 );
1097- weightWidth = weightShape.getDimSize (2 );
1098- weightDepth = weightShape.getDimSize (3 );
1096+ weightDepth = weightShape.getDimSize (1 );
1097+ weightHeight = weightShape.getDimSize (2 );
1098+ weightWidth = weightShape.getDimSize (3 );
10991099 }
11001100
11011101 // Bias shape can describe the output channels.
11021102 ShapeAdaptor biasShape = operands.getShape (adaptor.getBias ());
1103- if (biasShape.hasRank ()) {
1104- outputShape[4 ] =
1105- (outputShape[4 ] == -1 ) ? biasShape.getDimSize (0 ) : outputShape[4 ];
1103+ if (biasShape.hasRank () && ShapedType::isDynamic (outputShape[4 ])) {
1104+ outputShape[4 ] = biasShape.getDimSize (0 );
11061105 }
11071106
11081107 llvm::SmallVector<int64_t > dilation;
1109- llvm::SmallVector<int64_t > padding ;
1108+ llvm::SmallVector<int64_t > pad ;
11101109 llvm::SmallVector<int64_t > stride;
11111110
11121111 getI64Values (adaptor.getDilation (), dilation);
1113- getI64Values (adaptor.getPad (), padding );
1112+ getI64Values (adaptor.getPad (), pad );
11141113 getI64Values (adaptor.getStride (), stride);
11151114
1116- if (!ShapedType::isDynamic (inputHeight ) &&
1117- !ShapedType::isDynamic (weightHeight )) {
1118- int32_t inputSize = inputHeight + padding [0 ] + padding [1 ];
1119- int32_t filterSize = (weightHeight - 1 ) * dilation[0 ] + 1 ;
1115+ if (!ShapedType::isDynamic (inputDepth ) &&
1116+ !ShapedType::isDynamic (weightDepth )) {
1117+ int32_t inputSize = inputDepth + pad [0 ] + pad [1 ];
1118+ int32_t filterSize = (weightDepth - 1 ) * dilation[0 ] + 1 ;
11201119 int32_t unstridedResult = inputSize - filterSize + 1 ;
11211120 outputShape[1 ] = (unstridedResult - 1 ) / stride[0 ] + 1 ;
11221121 }
11231122
1124- if (!ShapedType::isDynamic (inputWidth ) &&
1125- !ShapedType::isDynamic (weightWidth )) {
1126- int32_t inputSize = inputWidth + padding [2 ] + padding [3 ];
1127- int32_t filterSize = (weightWidth - 1 ) * dilation[1 ] + 1 ;
1123+ if (!ShapedType::isDynamic (inputHeight ) &&
1124+ !ShapedType::isDynamic (weightHeight )) {
1125+ int32_t inputSize = inputHeight + pad [2 ] + pad [3 ];
1126+ int32_t filterSize = (weightHeight - 1 ) * dilation[1 ] + 1 ;
11281127 int32_t unstridedResult = inputSize - filterSize + 1 ;
11291128 outputShape[2 ] = (unstridedResult - 1 ) / stride[1 ] + 1 ;
11301129 }
11311130
1132- if (!ShapedType::isDynamic (inputDepth ) &&
1133- !ShapedType::isDynamic (weightDepth )) {
1134- int32_t inputSize = inputDepth + padding [4 ] + padding [5 ];
1135- int32_t filterSize = (weightDepth - 1 ) * dilation[2 ] + 1 ;
1131+ if (!ShapedType::isDynamic (inputWidth ) &&
1132+ !ShapedType::isDynamic (weightWidth )) {
1133+ int32_t inputSize = inputWidth + pad [4 ] + pad [5 ];
1134+ int32_t filterSize = (weightWidth - 1 ) * dilation[2 ] + 1 ;
11361135 int32_t unstridedResult = inputSize - filterSize + 1 ;
11371136 outputShape[3 ] = (unstridedResult - 1 ) / stride[2 ] + 1 ;
11381137 }
0 commit comments