11#include < limits>
22
3- #include " torch/csrc/jit/ir/ir.h"
4- // #include "torch/csrc/jit/ir/constants.h"
53#include " ATen/core/List.h"
64#include " ATen/core/functional.h"
75#include " ATen/core/ivalue.h"
86#include " ATen/core/stack.h"
97#include " c10/util/intrusive_ptr.h"
8+ #include " torch/csrc/jit/ir/ir.h"
109#include " torch/torch.h"
1110
1211#include " core/conversion/evaluators/eval_macros.h"
@@ -24,28 +23,28 @@ auto prim_registrations =
2423 RegisterNodeEvaluators ()
2524 .evaluator(
2625 {torch::jit::prim::Constant,
27- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
26+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
2827 if (n->output ()->type ()->kind () == at::FunctionType::Kind) {
2928 return {};
3029 }
3130 return evaluators::toIValue (n->output ());
3231 }})
3332 .evaluator(
3433 {torch::jit::prim::NumToTensor,
35- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
34+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3635 return evaluators::scalar_to_tensor (args.at (n->input (0 )).IValue ()->toScalar ());
3736 }})
3837 .evaluator(
3938 {torch::jit::prim::ListUnpack,
40- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
39+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4140 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
4241 const torch::jit::IValue* outputs = args.at (n->input ()).IValue ();
4342 auto outputVec = outputs->toList ().vec ();
4443 return std::move (c10::ivalue::Tuple::create (outputVec));
4544 }})
4645 .evaluator(
4746 {torch::jit::prim::ListConstruct,
48- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
47+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4948 const auto num_inputs = n->inputs ().size ();
5049 if (constTypesOnly (args)) {
5150 c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
8988 return c10::optional<torch::jit::IValue>(std::move (torch::jit::IValue (list)));
9089 }
9190 } else {
92- c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
93- c10::TypePtr elementType = lt->getElementType ();
94- auto list = c10::impl::GenericList (elementType);
91+ // List would be of IValues (with ITensors embedded in them)
92+ auto list = c10::impl::GenericList (c10::AnyType::get ());
9593 list.reserve (num_inputs);
9694 for (auto in : n->inputs ()) {
9795 if (args.at (in).isITensor ()) {
@@ -103,8 +101,27 @@ auto prim_registrations =
103101 if (args.at (in).IValue ()->isNone ()) {
104102 auto ival = torch::jit::IValue ();
105103 list.emplace_back (std::move (ival));
104+ } else if (args.at (in).IValue ()->isInt ()) {
105+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
106+ ctx, torch::tensor ({args.at (in).unwrapToInt ()}).to (torch::kI32 ));
107+ auto tensor_holder = TensorContainer ();
108+ tensor_holder.hold_tensor (itensor);
109+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
110+ list.emplace_back (std::move (ival));
111+ } else if (args.at (in).IValue ()->isDouble ()) {
112+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
113+ ctx, torch::tensor ({args.at (in).unwrapToDouble ()}).to (torch::kFloat ));
114+ auto tensor_holder = TensorContainer ();
115+ tensor_holder.hold_tensor (itensor);
116+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
117+ list.emplace_back (std::move (ival));
106118 } else {
107- list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
119+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
120+ ctx, std::move (args.at (in).unwrapToTensor ()));
121+ auto tensor_holder = TensorContainer ();
122+ tensor_holder.hold_tensor (itensor);
123+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
124+ list.emplace_back (std::move (ival));
108125 }
109126 }
110127 }
@@ -113,7 +130,7 @@ auto prim_registrations =
113130 }})
114131 .evaluator(
115132 {c10::Symbol::fromQualString (" prim::dtype" ),
116- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
117134 auto input = args.at (n->input (0 ));
118135 if (input.isITensor ()) {
119136 auto trt_dtype = input.ITensor ()->getType ();
@@ -136,7 +153,7 @@ auto prim_registrations =
136153 })})
137154 .evaluator(
138155 {c10::Symbol::fromQualString (" prim::min" ),
139- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
156+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
140157 if (n->inputs ().size () == 1 ) {
141158 auto a = args.at (n->input (0 )).unwrapToIntList ();
142159 int64_t min = std::numeric_limits<int64_t >::max ();
@@ -198,7 +215,7 @@ auto prim_registrations =
198215 })})
199216 .evaluator(
200217 {c10::Symbol::fromQualString (" prim::max" ),
201- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
218+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
202219 if (n->inputs ().size () == 1 ) {
203220 auto a = args.at (n->input (0 )).unwrapToIntList ();
204221 int64_t max = std::numeric_limits<int64_t >::min ();
@@ -260,7 +277,7 @@ auto prim_registrations =
260277 })})
261278 .evaluator(
262279 {c10::Symbol::fromQualString (" prim::shape" ),
263- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
280+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
264281 LOG_WARNING (" There may be undefined behavior using dynamic shape and prim::shape" );
265282 auto tensor_var = args.at (n->input (0 ));
266283 if (tensor_var.isITensor ()) {
@@ -274,7 +291,7 @@ auto prim_registrations =
274291 EvalOptions ().validSchemas ({" prim::shape(Tensor a) -> (int[])" })})
275292 .evaluator(
276293 {torch::jit::prim::TupleConstruct,
277- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
294+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
278295 c10::IValue tuple = c10::ivalue::Tuple::create ();
279296 std::vector<c10::IValue> elems;
280297 for (auto in : n->inputs ()) {
@@ -292,7 +309,7 @@ auto prim_registrations =
292309 }})
293310 .evaluator(
294311 {torch::jit::prim::TupleIndex,
295- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
312+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
296313 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
297314 auto tuple = args.at (n->input (0 )).IValue ()->toTuple ();
298315 int64_t idx = args.at (n->input (1 )).IValue ()->toInt ();
@@ -302,24 +319,24 @@ auto prim_registrations =
302319 EvalOptions ().validSchemas ({" prim::TupleIndex(Any tup, int i) -> (Any)" })})
303320 .evaluator(
304321 {torch::jit::prim::TupleUnpack,
305- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
322+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
306323 // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
307324 auto output = args.at (n->input ()).IValue ()->toTuple ();
308325 return c10::optional<torch::jit::IValue>(std::move (output));
309326 }})
310327 .evaluator(
311328 {c10::Symbol::fromQualString (" prim::unchecked_cast" ),
312- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
329+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
313330 return *(args.at (n->input (0 )).IValue ());
314331 }})
315332 .evaluator(
316333 {c10::Symbol::fromQualString (" prim::Uninitialized" ),
317- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
334+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318335 return c10::IValue::uninitialized ();
319336 }})
320337 .evaluator(
321338 {c10::Symbol::fromQualString (" prim::RaiseException" ),
322- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
339+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323340 auto exception = args.at (n->input (0 )).IValue ();
324341 TORCHTRT_THROW_ERROR (" Error from TorchScript: " << *exception);
325342 return {};
@@ -328,4 +345,4 @@ auto prim_registrations =
328345} // namespace evaluators
329346} // namespace conversion
330347} // namespace core
331- } // namespace torch_tensorrt
348+ } // namespace torch_tensorrt