@@ -103,6 +103,7 @@ torch::jit::Node* createCastNode(
     SegmentedBlock& seg_block,
     size_t index,
     bool is_input,
+    at::ScalarType dtype,
     std::string device,
     bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
@@ -115,7 +116,7 @@ torch::jit::Node* createCastNode(
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
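Note on the magic numbers replaced in this hunk and the next: `3` and `4` are the raw `at::ScalarType` enum values for `Int` (int32) and `Long` (int64), which the new `dtype` parameter now spells out at each call site. A compile-time check of that mapping, added here for illustration only (not part of the patch):

```cpp
// Illustration only: the ScalarType values the old integer literals
// relied on (enum order is defined in c10/core/ScalarType.h).
#include <torch/script.h>

static_assert(static_cast<int>(at::kByte) == 0, "Byte (uint8) == 0");
static_assert(static_cast<int>(at::kInt) == 3, "Int (int32) == 3");
static_assert(static_cast<int>(at::kLong) == 4, "Long (int64) == 4");
```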
@@ -127,7 +128,7 @@ torch::jit::Node* createCastNode(
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
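To make the else-branch concrete: a minimal standalone sketch, assuming the public `torch::jit` IR API, of how an `aten::to` cast node can be assembled from constants like those above. The graph, the `cuda:0` device string, and the Long dtype are illustrative choices, not the patch's values:

```cpp
#include <torch/script.h>
#include <iostream>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto* x = g->addInput("x");
  x->setType(c10::TensorType::get());

  // Constants mirroring the branch above: dtype, bool flags for
  // non_blocking/copy, a device, and None for the optional memory_format.
  auto* const_type = g->insertConstant(at::kLong);
  auto* const_zero = g->insertConstant(0);
  const_zero->setType(c10::BoolType::get());
  auto* cuda = g->insertConstant(std::string("cuda:0"));
  cuda->setType(c10::DeviceObjType::get());
  auto* none_val = g->insertNode(g->createNone())->output();

  // aten::to(Tensor, Device, ScalarType, bool, bool, MemoryFormat?) overload
  auto* cast_node = g->create(
      c10::Symbol::fromQualString("aten::to"),
      {x, cuda, const_type, const_zero, const_zero, none_val});
  g->appendNode(cast_node);
  g->registerOutput(cast_node->output());

  std::cout << *g << std::endl; // dump the IR to inspect the inserted cast
  return 0;
}
```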
@@ -230,17 +231,28 @@ void getSegmentsOutputByRunning(
   // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
   if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
-    if (partitioning_info.truncate_long_and_double) {
-      for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-        if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
-          auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
-          at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-          if (t == at::kLong) {
-            // we add a cast operation to cast the type to Int64
-            auto cast_node = createCastNode(seg_block, i, true, target_device);
-            seg_block.g()->prependNode(cast_node);
-            seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
-          }
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
+          // we add a cast operation to cast the type to Int64
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is cast to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
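The reworked loop now decides per input: Long tensors get a cast when truncation is enabled, Byte tensors get a forced cast when `cast_int8_inputs` is set. A hedged standalone illustration of that dispatch on synthetic tensors (not the library's code; the flags stand in for `partitioning_info`):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Synthetic stand-ins for the per-input tensors the loop inspects.
  auto long_input = torch::ones({2, 2}, torch::kLong);
  auto byte_input = torch::ones({2, 2}, torch::kByte);
  bool truncate_long_and_double = true; // stands in for partitioning_info flags
  bool cast_int8_inputs = true;

  for (const auto& t : {long_input, byte_input}) {
    at::ScalarType st = t.scalar_type();
    if (st == at::kLong && truncate_long_and_double) {
      std::cout << "Long input: would prepend an aten::to(Long) cast\n";
    } else if (st == at::kByte && cast_int8_inputs) {
      std::cout << "Byte input: would prepend a forced aten::to(Byte) cast\n";
    }
  }
  return 0;
}
```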
@@ -250,14 +262,22 @@ void getSegmentsOutputByRunning(
       auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
-      // If the input has type Long and truncation was requested, insert truncate
+      // If the output has type Long and truncation was requested, insert truncate
       if (t == at::kLong && partitioning_info.truncate_long_and_double) {
-        auto cast_node = createCastNode(seg_block, i, false, target_device);
+        LOG_DEBUG(
+            "Detected graph Long tensor output type during shape analysis, "
+            << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+            << "receives an Int-type tensor input.");
+        auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
         seg_block.g()->appendNode(cast_node);
         seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
       } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
-        // If the input has type Byte and truncation was requested, insert Integer cast
-        auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
+        LOG_DEBUG(
+            "Detected graph Byte tensor output type during shape analysis, "
+            << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+            << "receives an Int-type tensor input.");
+        // If the output has type Byte and casting was requested, insert Integer cast
+        auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
         seg_block.g()->appendNode(cast_node);
         seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
       }
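On the output side, both Long and Byte outputs are cast to Int before re-entering a TensorRT block. A hedged sketch of the numerical effect of the Long-to-Int truncation that `truncate_long_and_double` opts into (standalone, not the patch's code; the sample values are illustrative):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // A Long output holding a value that does not fit in 32 bits.
  auto out = torch::tensor({int64_t{1}, (int64_t{1} << 40) + 7}, torch::kLong);
  // The same element-wise conversion the inserted aten::to(Int) performs:
  auto as_int = out.to(at::kInt);
  std::cout << as_int << std::endl; // second element keeps only the low 32 bits (7)
  return 0;
}
```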