@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
120120 ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
121121}
122122
123+ TEST (Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
124+ const auto graph = R"IR(
125+ graph(%0 : Tensor,
126+ %1 : Tensor,
127+ %2 : Tensor):
128+ %3 : int[] = prim::Constant[value=[-1, 5]]()
129+ %4 : int[] = prim::Constant[value=[-1]]()
130+ %5 : int = prim::Constant[value=2]()
131+ %6 : int = prim::Constant[value=4]()
132+ %7 : int = prim::Constant[value=5]()
133+ %8 : int = prim::Constant[value=0]()
134+ %9 : bool = prim::Constant[value=0]()
135+ %10 : NoneType = prim::Constant()
136+ %11 : int = prim::Constant[value=1]()
137+ %12: Tensor = aten::reshape(%1, %4)
138+ %13: Tensor = aten::reshape(%2, %3)
139+ %14: Tensor = aten::reshape(%1, %3)
140+ %15 : Tensor = aten::to(%12, %6, %9, %9, %10)
141+ %16 : int = aten::size(%1, %8)
142+ %17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
143+ %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
144+ %20 : Tensor = aten::reshape(%18, %17)
145+ return (%20))IR" ;
146+
147+ auto g = std::make_shared<torch::jit::Graph>();
148+ torch::jit::parseIR (graph, g.get ());
149+
150+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
151+ partition_info.enabled = true ;
152+ partition_info.min_block_size = 3 ;
153+ std::unordered_map<torch::jit::Node*, int > fallback_nodes;
154+ std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
155+ torch_tensorrt::core::partitioning::segment_graph (g->block (), partition_info, fallback_nodes);
156+ ASSERT_TRUE (
157+ checkSegmentedBlockNumber (segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
158+ ASSERT_TRUE (
159+ checkSegmentedBlockNumber (segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch , 1 ));
160+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 , 3 }, {4 , 5 , 6 , 7 }}));
161+ }
162+
123163TEST (Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
124164 const auto graph = R"IR(
125165 graph(%0 : Tensor,
0 commit comments