@@ -99,7 +99,7 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
9999 return nullptr ;
100100}
101101
102- torch::jit::Node* createCastNode (SegmentedBlock& seg_block, size_t index, bool is_input) {
102+ torch::jit::Node* createCastNode (SegmentedBlock& seg_block, size_t index, bool is_input, std::string device ) {
103103 auto cast_raw_value = is_input ? seg_block.raw_inputs ()[index] : seg_block.raw_outputs ()[index];
104104 auto cast_subgraph_value = is_input ? seg_block.inputs ()[index] : seg_block.outputs ()[index];
105105 torch::jit::Node* cast_node = getUpstreamCastNode (cast_raw_value);
@@ -125,8 +125,11 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
125125 auto const_type = is_input ? g->insertConstant (4 ) : g->insertConstant (3 );
126126 auto const_zero = g->insertConstant (0 );
127127 const_zero->setType (torch::jit::BoolType::get ());
128+ auto cuda = g->insertConstant (device);
129+ cuda->setType (torch::jit::DeviceObjType::get ());
128130 auto none_val = g->insertNode (g->createNone ())->output ();
129- cast_node = g->create (torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
131+ cast_node =
132+ g->create (torch::jit::aten::to, {cast_subgraph_value, cuda, const_type, const_zero, const_zero, none_val});
130133 }
131134 return cast_node;
132135}
@@ -217,6 +220,8 @@ void getSegmentsOutputByRunning(
217220 ivalues_maps[output] = jit_results[idx++];
218221 }
219222
223+ auto target_device = partitioning_info.getGPUDeviceString ();
224+
220225 // auto int64 <=> int32 conversion
221226 if (seg_block.target () == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double ) {
222227 // First, check if there is Int64 input
@@ -226,7 +231,7 @@ void getSegmentsOutputByRunning(
226231 at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
227232 if (t == at::kLong ) {
228233 // we add a cast operation to cast the type to Int64
229- auto cast_node = createCastNode (seg_block, i, true );
234+ auto cast_node = createCastNode (seg_block, i, true , target_device );
230235 seg_block.g ()->prependNode (cast_node);
231236 seg_block.inputs ()[i]->replaceAllUsesAfterNodeWith (cast_node, cast_node->outputs ()[0 ]);
232237 }
@@ -237,7 +242,7 @@ void getSegmentsOutputByRunning(
237242 auto cur_ivalue = ivalues_maps[seg_block.raw_outputs ()[i]];
238243 at::ScalarType t = cur_ivalue.toTensor ().scalar_type ();
239244 if (t == at::kLong ) {
240- auto cast_node = createCastNode (seg_block, i, false );
245+ auto cast_node = createCastNode (seg_block, i, false , target_device );
241246 seg_block.g ()->appendNode (cast_node);
242247 seg_block.g ()->block ()->replaceOutput (i, cast_node->outputs ()[0 ]);
243248 }
0 commit comments