@@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257257 int count = count_trt_engines (fallback_g);
258258 ASSERT_TRUE (count == 2 );
259259}
260+
261+ TEST (Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
262+ /* parseIR does not support "= aten::_set_item" so we will build this graph manually
263+ const auto graph = R"IR(
264+ graph(%x : Tensor,
265+ %y : Tensor):
266+ %2 : str = prim::Constant[value="INS"]()
267+ %3 : str = prim::Constant[value="OUTS"]()
268+ %4 : bool = prim::Constant[value=0]()
269+ %5 : int = prim::Constant[value=-1]()
270+ %6 : Dict(str, Tensor) = prim::DictConstruct()
271+ = aten::_set_item(%6, %2, %x)
272+ %7 : Tensor = aten::__getitem__(%6, %2)
273+ %8 : Tensor = aten::lt(%7, %y)
274+ %9 : Tensor?[] = prim::ListConstruct(%8)
275+ %10 : int = prim::dtype(%7)
276+ %11 : Device = prim::device(%7)
277+ %12 : Tensor = aten::tensor(%5, %10, %11, %4)
278+ %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
279+ = aten::_set_item(%6, %3, %7)
280+ %14 : Tensor = aten::__getitem__(%6, %2)
281+ %15 : Tensor = aten::__getitem__(%6, %3)
282+ return (%14, %15))IR";
283+ */
284+ auto g = std::make_shared<torch::jit::Graph>();
285+ auto x = g->insertInput (0 , " x" );
286+ auto y = g->insertInput (1 , " y" );
287+ torch::jit::IValue ins_key (" INS" );
288+ auto ins_key_val = g->insertConstant (ins_key);
289+ torch::jit::IValue outs_key (" OUTS" );
290+ auto outs_key_val = g->insertConstant (outs_key);
291+ torch::jit::IValue zero (0 );
292+ auto false_const_val = g->insertConstant (zero);
293+ false_const_val->setType (c10::BoolType::get ());
294+ torch::jit::IValue neg_one (-1 );
295+ auto neg_one_const_val = g->insertConstant (neg_one);
296+ auto dict_node = g->createDict (
297+ ins_key_val->type (),
298+ x->type (),
299+ torch::jit::ArrayRef<torch::jit::Value*>(),
300+ torch::jit::ArrayRef<torch::jit::Value*>());
301+ g->insertNode (dict_node);
302+ auto set_node = g->create (
303+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
304+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x},
305+ 0 );
306+ g->insertNode (set_node);
307+ auto get_node = g->create (
308+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
309+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
310+ 1 );
311+ g->insertNode (get_node);
312+ auto lt_node = g->create (
313+ torch::jit::Symbol::fromQualString (" aten::lt" ),
314+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y},
315+ 1 );
316+ g->insertNode (lt_node);
317+ auto list_node = g->createList (
318+ at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
319+ g->insertNode (list_node);
320+ auto dtype_node = g->create (
321+ torch::jit::Symbol::fromQualString (" prim::dtype" ),
322+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
323+ 1 );
324+ dtype_node->output ()->setType (neg_one_const_val->type ());
325+ g->insertNode (dtype_node);
326+ auto device_node = g->create (
327+ torch::jit::Symbol::fromQualString (" prim::device" ),
328+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
329+ 1 );
330+ device_node->output ()->setType (c10::DeviceObjType::get ());
331+ g->insertNode (device_node);
332+ auto tensor_node = g->create (
333+ torch::jit::Symbol::fromQualString (" aten::tensor" ),
334+ torch::jit::ArrayRef<torch::jit::Value*>{
335+ neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val},
336+ 1 );
337+ g->insertNode (tensor_node);
338+ auto index_put_node = g->create (
339+ torch::jit::Symbol::fromQualString (" aten::index_put_" ),
340+ torch::jit::ArrayRef<torch::jit::Value*>{
341+ get_node->output (), list_node->output (), tensor_node->output (), false_const_val},
342+ 1 );
343+ g->insertNode (index_put_node);
344+ auto out_set_node = g->create (
345+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
346+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()},
347+ 0 );
348+ g->insertNode (out_set_node);
349+ auto get_ins_node = g->create (
350+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
351+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
352+ 1 );
353+ g->insertNode (get_ins_node);
354+ auto get_outs_node = g->create (
355+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
356+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val},
357+ 1 );
358+ g->insertNode (get_outs_node);
359+ g->registerOutput (get_ins_node->output ());
360+ g->registerOutput (get_outs_node->output ());
361+
362+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
363+ partition_info.enabled = true ;
364+ std::vector<torch_tensorrt::core::ir::Input> inputs;
365+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
366+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
367+
368+ std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
369+ std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
370+ for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
371+ inputs_map.insert ({g->inputs ()[i], inputs[i]});
372+ input_types.insert ({g->inputs ()[i], {at::kFloat }});
373+ }
374+ auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
375+ auto segmented_blocks = torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
376+
377+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
378+ for (const auto & segmented_block : segmented_blocks) {
379+ if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
380+ ++trt_block_cnt;
381+ ASSERT_TRUE (checkSegmentedBlockInputType (segmented_block, [](torch::jit::TypePtr type_ptr) {
382+ return type_ptr->isSubtypeOf (torch::jit::TensorType::get ());
383+ }));
384+ } else {
385+ ++torch_block_cnt;
386+ bool output_dict = false ;
387+ bool input_dict = false ;
388+ auto dict_type = dict_node->output ()->type ();
389+ for (auto in : segmented_block.raw_inputs ()) {
390+ if (in->type ()->isSubtypeOf (dict_type)) {
391+ input_dict = true ;
392+ }
393+ }
394+ for (auto out : segmented_block.raw_outputs ()) {
395+ if (out->type ()->isSubtypeOf (dict_type)) {
396+ output_dict = true ;
397+ }
398+ }
399+ EXPECT_TRUE (output_dict ^ input_dict);
400+ }
401+ }
402+ ASSERT_TRUE (trt_block_cnt == 1 && torch_block_cnt == 2 );
403+ }
0 commit comments