@@ -27,78 +27,14 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe
2727 }
2828}
2929
30- bool add_expand (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
31- auto input_dims = in->getDimensions ();
32- TORCHTRT_CHECK (
33- input_dims.nbDims <= expandedDims.nbDims ,
34- " Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions" );
35-
36- // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
37- for (int64_t i = expandedDims.nbDims - 1 ; i >= 0 ; --i) {
38- int64_t offset = expandedDims.nbDims - 1 - i;
39- int64_t dim = input_dims.nbDims - 1 - offset;
40- int64_t size = (dim >= 0 ) ? input_dims.d [dim] : 1 ;
41- int64_t targetSize = expandedDims.d [i];
42- // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
43- if (targetSize != -1 ) {
44- if (size != targetSize) {
45- if (size != 1 ) {
46- TORCHTRT_THROW_ERROR (
47- " The expanded size of tensor (" << targetSize << " )"
48- << " must match the existing size (" << size << " )"
49- << " at dimension " << i);
50- }
51- }
52- } else {
53- // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
54- // not [-1, 3, 4].
55- if (dim < 0 ) {
56- TORCHTRT_THROW_ERROR (
57- " The expanded size of the tensor (" << targetSize << " ) isn't allowed in a leading, non-existing dimension "
58- << i);
59- } else {
60- // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
61- expandedDims.d [i] = input_dims.d [dim];
62- }
63- }
64- }
65-
66- auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims ;
67- if (num_expand_dims > 0 ) {
68- nvinfer1::Dims reshape_dims;
69- reshape_dims.nbDims = expandedDims.nbDims ;
70- for (int64_t i = 0 ; i < num_expand_dims; i++) {
71- reshape_dims.d [i] = 1 ;
72- }
73- for (int64_t i = 0 ; i < input_dims.nbDims ; i++) {
74- reshape_dims.d [num_expand_dims + i] = input_dims.d [i];
75- }
76- // Add a reshape layer to expand dims
77- auto reshape_layer = ctx->net ->addShuffle (*in);
78- reshape_layer->setReshapeDimensions (reshape_dims);
79- in = reshape_layer->getOutput (0 );
80- LOG_DEBUG (" Input reshaped to : " << in->getDimensions () << " from " << input_dims);
81- }
82-
83- // Start the slicing from beginning of tensor since this is an expand layer
84- std::vector<int64_t > start_vec (expandedDims.nbDims , 0 );
85- auto start_offset = util::toDims (c10::IntArrayRef (start_vec));
86-
87- // Set the stride of non singleton dimension to 1
88- std::vector<int64_t > strides_vec (expandedDims.nbDims , 0 );
89- for (int64_t i = 0 ; i < expandedDims.nbDims ; i++) {
90- strides_vec[i] = (in->getDimensions ().d [i] != 1 );
91- }
92-
93- auto strides = util::toDims (c10::IntArrayRef (strides_vec));
94- // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
95- auto slice_layer = ctx->net ->addSlice (*in, start_offset, expandedDims, strides);
96- slice_layer->setName (util::node_info (n).c_str ());
97-
98- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], slice_layer->getOutput (0 ));
99-
30+ bool add_expand_static (
31+ ConversionCtx* ctx,
32+ const torch::jit::Node* n,
33+ nvinfer1::ITensor* in,
34+ nvinfer1::Dims expandedDims) {
35+ auto expand_out = add_expand (ctx, in, expandedDims);
36+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], expand_out);
10037 LOG_DEBUG (" Expand layer output tensor shape: " << out->getDimensions ());
101-
10238 return true ;
10339}
10440
@@ -209,7 +145,7 @@ auto expand_registrations TORCHTRT_UNUSED =
209145 auto expandedDimsTensor = tensor_to_const (ctx, thExpanded_size);
210146 return add_expand_dynamic (ctx, n, in, expandedDimsTensor, expandedDims, true );
211147 } else {
212- return add_expand (ctx, n, in, expandedDims);
148+ return add_expand_static (ctx, n, in, expandedDims);
213149 }
214150 }})
215151 .pattern(
@@ -223,7 +159,7 @@ auto expand_registrations TORCHTRT_UNUSED =
223159 if (ctx->input_is_dynamic ) {
224160 return add_expand_dynamic (ctx, n, in, getShapeOutput (ctx, targetTensor), targetDims, false );
225161 } else {
226- return add_expand (ctx, n, in, targetDims);
162+ return add_expand_static (ctx, n, in, targetDims);
227163 }
228164 }})
229165 .pattern(
0 commit comments