@@ -53,23 +53,66 @@ void create_plugin(
   LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
 }

-auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensor();
-       auto in_shape = util::toVec(in->getDimensions());
-       auto order = args[1].unwrapToScalar().to<int32_t>();
-       auto axes_values = args[2].unwrapToIntList().vec();
-       std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
-       auto keep_dims = (int32_t)args[3].unwrapToBool();
-       LOG_DEBUG("Order of normalize_plugin: " << order);
-       LOG_DEBUG("Axis: " << axes);
-       LOG_DEBUG("keep_dims: " << keep_dims);
-       create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
-       return true;
-     }
-
-    });
+auto normalize_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto in_shape = util::toVec(in->getDimensions());
+               auto order = args[1].unwrapToScalar().to<int32_t>();
+               auto axes_values = args[2].unwrapToIntList().vec();
+               std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
+               auto keep_dims = (int32_t)args[3].unwrapToBool();
+               LOG_DEBUG("Order of normalize_plugin: " << order);
+               LOG_DEBUG("Axis: " << axes);
+               LOG_DEBUG("keep_dims: " << keep_dims);
+               create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
+               return true;
+             }
+
+            })
+        .pattern(
+            {"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto axes_values = args[1].unwrapToIntList().vec();
+               auto keep_dims = args[2].unwrapToBool();
+
+               int32_t axes_mask = 0;
+               auto self_nb_dims = self->getDimensions().nbDims;
+               for (size_t i = 0UL; i < axes_values.size(); ++i) {
+                 auto axis = axes_values[i];
+                 if (axis < 0) {
+                   axis += self_nb_dims;
+                 }
+                 TORCHTRT_CHECK(
+                     axis < self_nb_dims,
+                     "aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank");
+                 axes_mask += 1 << axis;
+               }
+               auto squared_layer = add_elementwise(
+                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+               TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
+               auto squared_output = squared_layer->getOutput(0);
+
+               auto sum_layer =
+                   ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+               sum_layer->setName((util::node_info(n) + "_sum").c_str());
+               auto sum_output = sum_layer->getOutput(0);
+
+               auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+               TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+               sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+               auto sqrt_output = sqrt_layer->getOutput(0);
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_output);
+
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }});

 } // namespace
 } // namespace impl
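
The layer chain itself (kPROD of the input with itself, kSUM over the masked axes, then kSQRT) is the Frobenius norm definition, sqrt(sum(x^2)), expressed as native TensorRT layers instead of a plugin. A plain-CPU reference sketch of the same computation, reducing over all elements for brevity; the function name frobenius_norm_ref is illustrative only:

#include <cmath>
#include <iostream>
#include <vector>

// Illustrative reference (not in the patch): the converter's
// kPROD -> kSUM -> kSQRT chain computes sqrt(sum(x * x)).
double frobenius_norm_ref(const std::vector<double>& x) {
  double sum_sq = 0.0;
  for (double v : x) {
    sum_sq += v * v; // kPROD(self, self), then kSUM
  }
  return std::sqrt(sum_sq); // kSQRT
}

int main() {
  std::cout << frobenius_norm_ref({3.0, 4.0}) << "\n"; // prints 5
}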