@@ -99,13 +99,18 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
9999 return nullptr ;
100100}
101101
102- torch::jit::Node* createCastNode (SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
102+ torch::jit::Node* createCastNode (
103+ SegmentedBlock& seg_block,
104+ size_t index,
105+ bool is_input,
106+ std::string device,
107+ bool force_create_node = false ) {
103108 auto cast_raw_value = is_input ? seg_block.raw_inputs ()[index] : seg_block.raw_outputs ()[index];
104109 auto cast_subgraph_value = is_input ? seg_block.inputs ()[index] : seg_block.outputs ()[index];
105110 torch::jit::Node* cast_node = getUpstreamCastNode (cast_raw_value);
106111 auto g = seg_block.g ();
107112 // if we can find an upstream aten::to node, we use its parameters for creating the new cast node
108- if (cast_node) {
113+ if (cast_node && !force_create_node ) {
109114 std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
110115 value_map.insert ({cast_node->inputs ()[0 ], cast_subgraph_value});
111116 if (!is_input) {
@@ -222,29 +227,39 @@ void getSegmentsOutputByRunning(
222227
223228 auto target_device = partitioning_info.getGPUDeviceString ();
224229
225- // auto int64 <=> int32 conversion
226- if (seg_block.target () == SegmentedBlock::kTorch && partitioning_info. truncate_long_and_double ) {
230+ // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
231+ if (seg_block.target () == SegmentedBlock::kTorch ) {
227232 // First, check if there is Int64 input
228- for (size_t i = 0 ; i < seg_block.inputs ().size (); ++i) {
229- if (ivalues_maps[seg_block.raw_inputs ()[i]].isTensor ()) {
230- auto cur_ivalue = ivalues_maps[seg_block.raw_inputs ()[i]];
231- at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
232- if (t == at::kLong ) {
233- // we add a cast operation to cast the type to Int64
234- auto cast_node = createCastNode (seg_block, i, true , target_device);
235- seg_block.g ()->prependNode (cast_node);
236- seg_block.inputs ()[i]->replaceAllUsesAfterNodeWith (cast_node, cast_node->outputs ()[0 ]);
233+ if (partitioning_info.truncate_long_and_double ) {
234+ for (size_t i = 0 ; i < seg_block.inputs ().size (); ++i) {
235+ if (ivalues_maps[seg_block.raw_inputs ()[i]].isTensor ()) {
236+ auto cur_ivalue = ivalues_maps[seg_block.raw_inputs ()[i]];
237+ at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
238+ if (t == at::kLong ) {
239+ // we add a cast operation to cast the type to Int64
240+ auto cast_node = createCastNode (seg_block, i, true , target_device);
241+ seg_block.g ()->prependNode (cast_node);
242+ seg_block.inputs ()[i]->replaceAllUsesAfterNodeWith (cast_node, cast_node->outputs ()[0 ]);
243+ }
237244 }
238245 }
239246 }
247+
240248 for (size_t i = 0 ; i < seg_block.outputs ().size (); ++i) {
241249 if (ivalues_maps[seg_block.raw_outputs ()[i]].isTensor ()) {
242250 auto cur_ivalue = ivalues_maps[seg_block.raw_outputs ()[i]];
243251 at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
244- if (t == at::kLong ) {
252+
253+ // If the output has type Long and truncation was requested, insert a truncating cast
254+ if (t == at::kLong && partitioning_info.truncate_long_and_double ) {
245255 auto cast_node = createCastNode (seg_block, i, false , target_device);
246256 seg_block.g ()->appendNode (cast_node);
247257 seg_block.g ()->block ()->replaceOutput (i, cast_node->outputs ()[0 ]);
258+ } else if (t == at::kByte && partitioning_info.cast_int8_inputs ) {
259+ // If the output has type Byte and int8 casting was requested, insert an Integer cast
260+ auto cast_node = createCastNode (seg_block, i, false , target_device, /* force_create_node=*/ true );
261+ seg_block.g ()->appendNode (cast_node);
262+ seg_block.g ()->block ()->replaceOutput (i, cast_node->outputs ()[0 ]);
248263 }
249264 }
250265 }
@@ -254,11 +269,13 @@ void getSegmentsOutputByRunning(
254269 std::vector<std::vector<int64_t >> input_shapes;
255270 std::vector<at::ScalarType> input_types;
256271 for (size_t i = 0 ; i < seg_block.inputs ().size (); ++i) {
257- if (ivalues_maps[seg_block.raw_inputs ()[i]].isTensor ()) {
272+ auto current_input = seg_block.raw_inputs ()[i];
273+
274+ if (ivalues_maps[current_input].isTensor ()) {
258275 // set the input_shape and data_type
259276 // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
260277 // shape inference
261- auto cur_ivalue = ivalues_maps[seg_block. raw_inputs ()[i] ];
278+ auto cur_ivalue = ivalues_maps[current_input ];
262279 at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
263280
264281 if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble )) {
@@ -271,10 +288,16 @@ void getSegmentsOutputByRunning(
271288 cur_ivalue = cur_ivalue.toTensor ().to (at::kFloat );
272289 LOG_WARNING (" Truncating graph input type from at::kDouble to at::kFloat" );
273290 }
291+
274292 c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType (cur_ivalue.toTensor ().dtype ());
275293 if (dtype == c10::nullopt ) {
276294 TORCHTRT_THROW_ERROR (" Unsupported input data type " << cur_ivalue.toTensor ().dtype ());
295+ } else if (dtype && dtype.value () == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs ) {
296+ // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
297+ // model itself is not quantized
298+ cur_ivalue = cur_ivalue.toTensor ().to (at::kInt );
277299 }
300+
278301 if (cur_ivalue.toTensor ().sizes ().size () == 0 ) {
279302 // handle Scalar types, which has sizes of []
280303 input_shapes.push_back (util::toVec (util::toDims (c10::List<int64_t >({1 }))));
@@ -297,6 +320,7 @@ void runShapeAnalysis(
297320 const ir::ShapeMode& shape_mode) {
298321 // register every segment's input shape, and it's running output IValues
299322 for (auto & seg_block : ctx->partitioned_blocks [block]) {
323+ LOG_GRAPH (" Running shape analysis on block " << seg_block);
300324 torch::jit::ConstantPooling (seg_block.g ());
301325 getSegmentsOutputByRunning (seg_block, example_tensor_map, ctx->settings , shape_mode);
302326 }
0 commit comments