@@ -180,15 +180,50 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
180180 return dims;
181181}
182182
183- nvinfer1::Dims squeezeDims (const nvinfer1::Dims& d, int pos, bool use_zeros) {
183+ int validateInputDimsForShuffle (const nvinfer1::Dims& d, bool input_is_dynamic) {
184+ int num_zeros_detected = 0 ;
185+
186+ // For each dimension, increment counter if that dimension has value 0
187+ for (int i = 0 ; i < d.nbDims ; i++) {
188+ if (d.d [i] == 0 ) {
189+ num_zeros_detected++;
190+ }
191+ }
192+
193+ // If the tensor from which the dimensions originate has dynamic shape and more than 1
194+ // zero dimension is detected, this constitutes an invalid shape to the TRT Shuffle Layer,
195+ // since dynamic dimensions to Shuffle Layers are generally represented with a 0
196+ // denoting to inherit the dimension from the input tensor, thus causing an
197+ // overload of the "0" dimension
198+ return (input_is_dynamic && num_zeros_detected > 1 ) ? -1 : num_zeros_detected;
199+ }
200+
201+ nvinfer1::Dims squeezeDims (const nvinfer1::Dims& d, int pos, bool use_zeros, bool swap_existing_zeros) {
184202 // acceptable range for pos is [0, d.nbDims]
185203 TORCHTRT_ASSERT (pos >= 0 && pos <= d.nbDims , " ERROR: Index to squeeze is out of bounds." );
186204
187205 nvinfer1::Dims dims;
188206 int j = 0 ;
189207 for (int i = 0 ; i < d.nbDims ; i++) {
190208 if (i != pos) {
191- dims.d [j++] = (use_zeros && d.d [i] == -1 ) ? 0 : d.d [i];
209+ // If zeros are replacing dynamic/existing dimensions,
210+ // Replace all instances of -1, indicating dynamic dimension
211+ // with 0, indicating copy the dimension from another tensor
212+ // (Generally used for reshape operations)
213+ if (use_zeros && d.d [i] == -1 ) {
214+ dims.d [j] = 0 ;
215+ // If zeros already exist in the dimensions (empty tensor),
216+ // Replace all instances of 0, indicating empty dimension
217+ // with -1, indicating inherit the dimension from reshape
218+ // (Generally used for reshape operations)
219+ } else if (swap_existing_zeros && d.d [i] == 0 ) {
220+ dims.d [j] = -1 ;
221+ // Otherwise, replace the dimension with the same value from the input
222+ } else {
223+ dims.d [j] = d.d [i];
224+ }
225+
226+ j++;
192227 }
193228 }
194229 dims.nbDims = j;
0 commit comments