@@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
589589 // Validate identical graphs after pooling constants and canonicalizing
590590 ASSERT_TRUE ((tg->toString () == sg->toString ()));
591591}
592+
593+ TEST (LoweringPasses, RemoveCollectionCastTuple) {
594+ // Ensure the lowering pass transforms the first graph into the second
595+ std::string source_graph = R"IR(
596+ graph(%x.1 : Tensor):
597+ %3 : int = prim::Constant[value=1]()
598+ %2 : int = prim::Constant[value=2]()
599+ %a.1 : Tensor = aten::mul(%x.1, %2)
600+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
601+ %c.1 : Tensor = aten::relu(%b.1)
602+ %d.1 : Tensor = aten::sqrt(%c.1)
603+ %8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
604+ return (%8))IR" ;
605+
606+ std::string target_graph = R"IR(
607+ graph(%x.1 : Tensor):
608+ %3 : int = prim::Constant[value=1]()
609+ %2 : int = prim::Constant[value=2]()
610+ %a.1 : Tensor = aten::mul(%x.1, %2)
611+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
612+ %c.1 : Tensor = aten::relu(%b.1)
613+ %d.1 : Tensor = aten::sqrt(%c.1)
614+ return (%c.1, %d.1, %b.1))IR" ;
615+
616+ // Ensure the lowering pass transforms the first graph into the second
617+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
618+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
619+ auto sg = std::make_shared<torch::jit::Graph>();
620+ torch::jit::parseIR (source_graph, sg.get ());
621+
622+ torch_tensorrt::core::lowering::passes::RemoveCollectionCast (sg);
623+ torch::jit::ConstantPooling (sg);
624+ sg = torch::jit::Canonicalize (sg, false );
625+
626+ auto tg = std::make_shared<torch::jit::Graph>();
627+ torch::jit::parseIR (target_graph, tg.get ());
628+
629+ torch::jit::ConstantPooling (tg);
630+ tg = torch::jit::Canonicalize (tg, false );
631+
632+ // Validate identical graphs after pooling constants and canonicalizing
633+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
634+ }
635+
636+ TEST (LoweringPasses, RemoveCollectionCastList) {
637+ // Ensure the lowering pass transforms the first graph into the second
638+ std::string source_graph = R"IR(
639+ graph(%x.1 : Tensor):
640+ %3 : int = prim::Constant[value=1]()
641+ %2 : int = prim::Constant[value=2]()
642+ %a.1 : Tensor = aten::mul(%x.1, %2)
643+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
644+ %c.1 : Tensor = aten::relu(%b.1)
645+ %d.1 : Tensor = aten::sqrt(%c.1)
646+ %8 : (Tensor, Tensor, Tensor) = prim::ListConstruct(%b.1, %c.1, %d.1)
647+ return (%8))IR" ;
648+
649+ std::string target_graph = R"IR(
650+ graph(%x.1 : Tensor):
651+ %3 : int = prim::Constant[value=1]()
652+ %2 : int = prim::Constant[value=2]()
653+ %a.1 : Tensor = aten::mul(%x.1, %2)
654+ %b.1 : Tensor = aten::add(%a.1, %2, %3)
655+ %c.1 : Tensor = aten::relu(%b.1)
656+ %d.1 : Tensor = aten::sqrt(%c.1)
657+ return (%b.1, %c.1, %d.1))IR" ;
658+
659+ // Ensure the lowering pass transforms the first graph into the second
660+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
661+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
662+ auto sg = std::make_shared<torch::jit::Graph>();
663+ torch::jit::parseIR (source_graph, sg.get ());
664+
665+ torch_tensorrt::core::lowering::passes::RemoveCollectionCast (sg);
666+ torch::jit::ConstantPooling (sg);
667+ sg = torch::jit::Canonicalize (sg, false );
668+
669+ auto tg = std::make_shared<torch::jit::Graph>();
670+ torch::jit::parseIR (target_graph, tg.get ());
671+
672+ torch::jit::ConstantPooling (tg);
673+ tg = torch::jit::Canonicalize (tg, false );
674+
675+ // Validate identical graphs after pooling constants and canonicalizing
676+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
677+ }
0 commit comments