@@ -16,127 +16,63 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
1616 {" aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)" ,
1717 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1818 auto in = args[0 ].ITensor ();
-       auto inDims = in->getDimensions();
-       int64_t inRank = inDims.nbDims;
+       auto in_dims = in->getDimensions();
+       int64_t in_rank = in_dims.nbDims;
        auto padding = args[1].unwrapToIntList().vec();
-       int64_t padSize = padding.size();
+       int64_t pad_size = padding.size();
        auto value = args[2].unwrapToScalar().to<float>();
        at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
-       auto valueTensor = tensor_to_const(ctx, value_tensor);
-       TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
-
-       int64_t l_pad = padSize / 2;
-       TORCHTRT_CHECK(
-           inRank >= (int64_t)l_pad,
-           "Length of pad should be no more than twice the number of "
-           "dimensions of the input. Pad length is "
-               << padSize << " while the input has " << inRank << " dimensions.");
-
-       // TODO negative padding. When the pad is negative, we need to crop the image.
-
-       std::vector<nvinfer1::ITensor*> tensors_vec;
-       // input: (N, C, D_in, H_in, W_in).
-       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
-       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
-       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
-       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
-       for (int64_t i = 0; i < l_pad; i++) {
-         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
-         int64_t padding_index = i * 2;
-
-         if (padding[padding_index] > 0) { // left/top/front padding value
-           tensors_vec.clear();
-           if (ctx->input_is_dynamic) {
-             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
-             auto indicesTensor = tensor_to_const(ctx, left_indices);
-             auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-             auto left_gather_out = left_gather_layer->getOutput(0);
-
-             // fill the left_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
+       auto value_itensor = tensor_to_const(ctx, value_tensor);
+       TORCHTRT_CHECK(pad_size % 2 == 0, "Length of pad must be even but instead it equals " << pad_size);
+
+       std::vector<int64_t> start(in_rank, 0);
+       std::vector<int64_t> total_padding(in_rank, 0);
+       std::vector<int64_t> stride(in_rank, 1);
+
+       // Padding is stored (left, right) starting from the last dim and working backwards
+       for (size_t i = 0UL; i < padding.size(); i += 2) {
+         auto left = padding[i];
+         TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
+         auto right = padding[i + 1];
+         TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
+         auto idx = in_rank - ((i / 2) + 1);
+         start[idx] = -left;
+         total_padding[idx] = left + right;
+       }
 
-           tensors_vec.push_back(in);
-           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-           concat_layer->setAxis(axis);
-           in = concat_layer->getOutput(0);
-           inDims = in->getDimensions();
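+       // With static shapes the slice size is known at build time as
+       // in_dims + total_padding per axis; for dynamic shapes the placeholder
+       // below is overwritten via slice_layer->setInput(2, ...) further down.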
+       auto size = stride; // placeholder for the dynamic case
+       if (!ctx->input_is_dynamic) {
+         size = total_padding;
+         for (size_t i = 0UL; i < total_padding.size(); ++i) {
+           size[i] += in_dims.d[i];
          }
+       }
 
-         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
-           tensors_vec.clear();
-           tensors_vec.push_back(in);
-
-           nvinfer1::ITensor* indicesTensor = NULL;
-           if (inDims.d[axis] == -1) {
-             auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
-             at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
-             auto dimTensor = tensor_to_const(ctx, dimValue);
-             indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
-             auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
-             indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
-                                 ->getOutput(0);
-           } else {
-             auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
-             indicesTensor = tensor_to_const(ctx, indices);
-           }
-           auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-           auto right_gather_out = right_gather_layer->getOutput(0);
-
-           if (ctx->input_is_dynamic) {
-             // fill the right_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index + 1]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index + 1];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
-           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-           concat_layer->setAxis(axis);
-           in = concat_layer->getOutput(0);
-           inDims = in->getDimensions();
-         }
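+       // A single ISliceLayer in kFILL mode implements the whole pad: negative
+       // start values index "before" the tensor, the enlarged size runs past
+       // its end, and every out-of-bounds element is replaced by the pad value
+       // bound to input slot 4.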
+       auto slice_layer = ctx->net->addSlice(
+           *in,
+           util::toDims(c10::IntArrayRef(start)),
+           util::toDims(c10::IntArrayRef(size)),
+           util::toDims(c10::IntArrayRef(stride)));
+       TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
+       slice_layer->setName((util::node_info(n) + "_slice").c_str());
+       slice_layer->setMode(nvinfer1::SliceMode::kFILL);
+       slice_layer->setInput(4, *value_itensor);
+
+       if (ctx->input_is_dynamic) {
+         // build the size at runtime using network layers: shape(in) + total_padding
+         auto shape_layer = ctx->net->addShape(*in);
+         TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+         shape_layer->setName((util::node_info(n) + "_shape").c_str());
+         auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
+
+         auto add_layer = ctx->net->addElementWise(
+             *shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
+         TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+         add_layer->setName((util::node_info(n) + "_add").c_str());
+         slice_layer->setInput(2, *add_layer->getOutput(0));
        }
 
-       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
        LOG_DEBUG("Output tensor shape: " << out->getDimensions());
        return true;
      }});
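
For reference, here is a minimal standalone sketch (not part of the patch) of the same index math, assuming a hypothetical rank-4 input of shape (1, 3, 4, 6) and pad = {1, 2, 3, 4}, i.e. (left, right) pairs for the W and then H axes, as in aten::constant_pad_nd:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_dims = {1, 3, 4, 6};  // assumed example shape (N, C, H, W)
  std::vector<int64_t> padding = {1, 2, 3, 4};  // (W_left, W_right, H_top, H_bottom)
  int64_t in_rank = static_cast<int64_t>(in_dims.size());

  std::vector<int64_t> start(in_rank, 0);
  std::vector<int64_t> total_padding(in_rank, 0);

  // Pad entries come in (left, right) pairs, last dimension first.
  for (size_t i = 0; i < padding.size(); i += 2) {
    auto idx = in_rank - ((i / 2) + 1);
    start[idx] = -padding[i];  // negative start: kFILL supplies the pad value here
    total_padding[idx] = padding[i] + padding[i + 1];
  }

  // Static-shape case: slice size is the input extent plus total padding per axis.
  for (int64_t i = 0; i < in_rank; ++i) {
    std::cout << "axis " << i << ": start=" << start[i]
              << " size=" << (in_dims[i] + total_padding[i]) << "\n";
  }
  return 0;  // prints start = {0, 0, -3, -1}, size = {1, 3, 11, 9}
}

The slice therefore reads H from index -3 with extent 11 (4 + 3 + 4) and W from index -1 with extent 9 (6 + 1 + 2); kFILL writes the pad value into every out-of-range position, replacing the gather/fill/concat chains the old converter built per axis.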