@@ -24,28 +24,28 @@ auto prim_registrations =
2424 RegisterNodeEvaluators ()
2525 .evaluator(
2626 {torch::jit::prim::Constant,
27- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
27+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
2828 if (n->output ()->type ()->kind () == at::FunctionType::Kind) {
2929 return {};
3030 }
3131 return evaluators::toIValue (n->output ());
3232 }})
3333 .evaluator(
3434 {torch::jit::prim::NumToTensor,
35- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
35+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3636 return evaluators::scalar_to_tensor (args.at (n->input (0 )).IValue ()->toScalar ());
3737 }})
3838 .evaluator(
3939 {torch::jit::prim::ListUnpack,
40- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
40+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4141 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
4242 const torch::jit::IValue* outputs = args.at (n->input ()).IValue ();
4343 auto outputVec = outputs->toList ().vec ();
4444 return std::move (c10::ivalue::Tuple::create (outputVec));
4545 }})
4646 .evaluator(
4747 {torch::jit::prim::ListConstruct,
48- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
48+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4949 const auto num_inputs = n->inputs ().size ();
5050 if (constTypesOnly (args)) {
5151 c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
@@ -103,8 +103,14 @@ auto prim_registrations =
103103 if (args.at (in).IValue ()->isNone ()) {
104104 auto ival = torch::jit::IValue ();
105105 list.emplace_back (std::move (ival));
106+ } else if (args.at (in).IValue ()->isInt ()) {
107+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, torch::tensor (args.at (in).unwrapToInt ()));
108+ auto tensor_holder = TensorContainer ();
109+ tensor_holder.hold_tensor (itensor);
110+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
111+ list.emplace_back (std::move (ival));
106112 } else {
107- list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
113+ list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
108114 }
109115 }
110116 }
@@ -113,7 +119,7 @@ auto prim_registrations =
113119 }})
114120 .evaluator(
115121 {c10::Symbol::fromQualString (" prim::dtype" ),
116- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
122+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
117123 auto input = args.at (n->input (0 ));
118124 if (input.isITensor ()) {
119125 auto trt_dtype = input.ITensor ()->getType ();
@@ -136,7 +142,7 @@ auto prim_registrations =
136142 })})
137143 .evaluator(
138144 {c10::Symbol::fromQualString (" prim::min" ),
139- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
145+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
140146 if (n->inputs ().size () == 1 ) {
141147 auto a = args.at (n->input (0 )).unwrapToIntList ();
142148 int64_t min = std::numeric_limits<int64_t >::max ();
@@ -198,7 +204,7 @@ auto prim_registrations =
198204 })})
199205 .evaluator(
200206 {c10::Symbol::fromQualString (" prim::max" ),
201- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
207+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
202208 if (n->inputs ().size () == 1 ) {
203209 auto a = args.at (n->input (0 )).unwrapToIntList ();
204210 int64_t max = std::numeric_limits<int64_t >::min ();
@@ -260,7 +266,7 @@ auto prim_registrations =
260266 })})
261267 .evaluator(
262268 {c10::Symbol::fromQualString (" prim::shape" ),
263- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
269+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
264270 LOG_WARNING (" There may be undefined behavior using dynamic shape and prim::shape" );
265271 auto tensor_var = args.at (n->input (0 ));
266272 if (tensor_var.isITensor ()) {
@@ -274,7 +280,7 @@ auto prim_registrations =
274280 EvalOptions ().validSchemas ({" prim::shape(Tensor a) -> (int[])" })})
275281 .evaluator(
276282 {torch::jit::prim::TupleConstruct,
277- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
283+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
278284 c10::IValue tuple = c10::ivalue::Tuple::create ();
279285 std::vector<c10::IValue> elems;
280286 for (auto in : n->inputs ()) {
@@ -292,7 +298,7 @@ auto prim_registrations =
292298 }})
293299 .evaluator(
294300 {torch::jit::prim::TupleIndex,
295- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
301+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
296302 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
297303 auto tuple = args.at (n->input (0 )).IValue ()->toTuple ();
298304 int64_t idx = args.at (n->input (1 )).IValue ()->toInt ();
@@ -302,24 +308,24 @@ auto prim_registrations =
302308 EvalOptions ().validSchemas ({" prim::TupleIndex(Any tup, int i) -> (Any)" })})
303309 .evaluator(
304310 {torch::jit::prim::TupleUnpack,
305- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
311+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
306312 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
307313 auto output = args.at (n->input ()).IValue ()->toTuple ();
308314 return c10::optional<torch::jit::IValue>(std::move (output));
309315 }})
310316 .evaluator(
311317 {c10::Symbol::fromQualString (" prim::unchecked_cast" ),
312- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
313319 return *(args.at (n->input (0 )).IValue ());
314320 }})
315321 .evaluator(
316322 {c10::Symbol::fromQualString (" prim::Uninitialized" ),
317- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318324 return c10::IValue::uninitialized ();
319325 }})
320326 .evaluator(
321327 {c10::Symbol::fromQualString (" prim::RaiseException" ),
322- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
328+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323329 auto exception = args.at (n->input (0 )).IValue ();
324330 TORCHTRT_THROW_ERROR (" Error from TorchScript: " << *exception);
325331 return {};
@@ -328,4 +334,4 @@ auto prim_registrations =
328334} // namespace evaluators
329335} // namespace conversion
330336} // namespace core
331- } // namespace torch_tensorrt
337+ } // namespace torch_tensorrt
0 commit comments