@@ -68,8 +68,8 @@ TRTEngine::TRTEngine(
6868 uint64_t inputs = 0 ;
6969 uint64_t outputs = 0 ;
7070
71- for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x ++) {
72- std::string bind_name = cuda_engine->getBindingName (x );
71+ for (int64_t trt_idx = 0 ; trt_idx < cuda_engine->getNbIOTensors (); trt_idx ++) {
72+ std::string bind_name = cuda_engine->getIOTensorName (trt_idx );
7373 LOG_DEBUG (" Binding name: " << bind_name);
7474 auto delim = bind_name.find (" ." );
7575 if (delim == std::string::npos) {
@@ -80,46 +80,45 @@ TRTEngine::TRTEngine(
8080 << bind_name
8181 << " \n Ensure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions" );
8282 }
83-
8483 std::string idx_s = bind_name.substr (delim + 1 );
85- uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
84+ uint64_t pyt_idx = static_cast <uint64_t >(std::stoi (idx_s));
8685
87- if (cuda_engine->bindingIsInput (x) ) {
86+ if (cuda_engine->getTensorIOMode (bind_name. c_str ()) == nvinfer1::TensorIOMode:: kINPUT ) {
8887 inputs++;
89- in_binding_map[x ] = idx ;
90- LOG_DEBUG (" TRT Binding: " << x << " : PYT Input: " << idx );
88+ in_binding_map[trt_idx ] = pyt_idx ;
89+ LOG_DEBUG (" TRT Binding index : " << trt_idx << " corresponds to PYT Input index : " << pyt_idx );
9190 } else {
9291 outputs++;
93- out_binding_map[x ] = idx ;
94- LOG_DEBUG (" TRT Binding: " << x << " : PYT Output: " << idx );
92+ out_binding_map[trt_idx ] = pyt_idx ;
93+ LOG_DEBUG (" TRT Binding index : " << trt_idx << " corresponds to PYT Output: " << pyt_idx );
9594 }
9695 }
9796
9897 num_io = std::make_pair (inputs, outputs);
9998 in_binding_names.resize (inputs);
10099 out_binding_names.resize (outputs);
101-
102- for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
103- std::string bind_name = cuda_engine->getBindingName (x);
104- if (cuda_engine->bindingIsInput (x)) {
100+ for (int64_t x = 0 ; x < cuda_engine->getNbIOTensors (); x++) {
101+ std::string bind_name = cuda_engine->getIOTensorName (x);
102+ if (cuda_engine->getTensorIOMode (bind_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ) {
105103 in_binding_names[in_binding_map.at (x)] = bind_name;
106104 } else {
107105 out_binding_names[out_binding_map.at (x)] = bind_name;
108106 }
109107 }
110108 } else {
111- uint64_t inputs = _in_binding_names.size ();
112- in_binding_names.resize (inputs );
113- for (size_t pyt_idx = 0 ; pyt_idx < inputs ; pyt_idx++) {
109+ uint64_t inputs_size = _in_binding_names.size ();
110+ in_binding_names.resize (inputs_size );
111+ for (size_t pyt_idx = 0 ; pyt_idx < inputs_size ; pyt_idx++) {
114112 auto binding_name = _in_binding_names[pyt_idx];
115113 auto trt_idx = cuda_engine->getBindingIndex (binding_name.c_str ());
116- TORCHTRT_CHECK ((trt_idx >= 0 ), " Could not find a TensorRT engine binding for input named " << binding_name );
114+ std::string engine_binded_name = cuda_engine-> getIOTensorName (pyt_idx );
117115 TORCHTRT_CHECK (
118- cuda_engine->bindingIsInput (trt_idx),
116+ (binding_name == engine_binded_name),
117+ " Could not find a TensorRT engine binding for input named " << binding_name);
118+ TORCHTRT_CHECK (
119+ (cuda_engine->getTensorIOMode (binding_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ),
119120 " Binding " << binding_name << " specified as input but found as output in TensorRT engine" );
120- LOG_DEBUG (
121- " Input binding name: " << binding_name << " (trt binding idx: " << trt_idx << " , "
122- << " pyt arg idx: " << pyt_idx << " )" );
121+ LOG_DEBUG (" Input binding name: " << binding_name << " pyt arg idx: " << pyt_idx << " )" );
123122 in_binding_map[trt_idx] = pyt_idx;
124123 in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
125124 }
@@ -129,17 +128,18 @@ TRTEngine::TRTEngine(
129128 for (size_t pyt_idx = 0 ; pyt_idx < outputs; pyt_idx++) {
130129 auto binding_name = _out_binding_names[pyt_idx];
131130 auto trt_idx = cuda_engine->getBindingIndex (binding_name.c_str ());
132- TORCHTRT_CHECK ((trt_idx >= 0 ), " Could not find a TensorRT engine binding for output named " << binding_name);
131+ std::string engine_binded_name = cuda_engine->getIOTensorName (inputs_size + pyt_idx);
132+ TORCHTRT_CHECK (
133+ (binding_name == engine_binded_name),
134+ " Could not find a TensorRT engine binding for output named " << binding_name);
133135 TORCHTRT_CHECK (
134- !cuda_engine->bindingIsInput (trt_idx ),
136+ !( cuda_engine->getTensorIOMode (binding_name. c_str ()) == nvinfer1::TensorIOMode:: kINPUT ),
135137 " Binding " << binding_name << " specified as output but found as input in TensorRT engine" );
136- LOG_DEBUG (
137- " Output binding name: " << binding_name << " (trt binding idx: " << trt_idx << " , "
138- << " pyt return idx: " << pyt_idx << " )" );
138+ LOG_DEBUG (" Output binding name: " << binding_name << " pyt return idx: " << inputs_size + pyt_idx << " )" );
139139 out_binding_map[trt_idx] = pyt_idx;
140140 out_binding_names[pyt_idx] = binding_name;
141141 }
142- num_io = std::make_pair (inputs , outputs);
142+ num_io = std::make_pair (inputs_size , outputs);
143143 }
144144
145145#ifndef NDEBUG
@@ -149,10 +149,10 @@ TRTEngine::TRTEngine(
149149}
150150
151151TRTEngine::~TRTEngine () {
152+ rt.reset ();
152153 trt_engine_profiler.reset ();
153154 exec_ctx.reset ();
154155 cuda_engine.reset ();
155- rt.reset ();
156156}
157157
158158void TRTEngine::disable_profiling () {
@@ -164,7 +164,7 @@ void TRTEngine::disable_profiling() {
164164}
165165
166166void TRTEngine::dump_engine_layer_info_to_file (const std::string& path) {
167- auto inspector = cuda_engine->createEngineInspector ();
167+ auto inspector = make_trt ( cuda_engine->createEngineInspector () );
168168 std::ofstream f (path);
169169 f << std::string (inspector->getEngineInformation (nvinfer1::LayerInformationFormat::kJSON ));
170170 f.close ();
@@ -208,23 +208,23 @@ std::string TRTEngine::to_str() const {
208208 std::stringstream ss;
209209 ss << " Torch-TensorRT TensorRT Engine:" << std::endl;
210210 ss << " Name: " << name << std::endl;
211- ss << " Bindings: { " << std::endl;
212- for (int64_t x = 0 ; x < cuda_engine-> getNbBindings (); x ++) {
213- if (cuda_engine-> bindingIsInput (x)) {
214- const uint64_t pyt_idx = in_binding_map. at (x) ;
215- ss << " ( " << x << " : " << in_binding_names. at (pyt_idx) << " ) Input: [ " << std::endl;
216- ss << " pytorch arg idx: " << pyt_idx << std::endl;
217- ss << " shape: " << exec_ctx-> getBindingDimensions (x) << std::endl;
218- ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getBindingDataType (x)) << std::endl;
219- ss << " ]" << std::endl;
220- } else {
221- const uint64_t pyt_idx = out_binding_map. at (x);
222- ss << " ( " << x << " : " << out_binding_names. at (pyt_idx) << " ) Output: [ " << std::endl;
223- ss << " pytorch return idx : " << pyt_idx << std::endl;
224- ss << " shape : " << exec_ctx-> getBindingDimensions (x) << std::endl;
225- ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getBindingDataType (x)) << std::endl;
226- ss << " ] " << std::endl;
227- }
211+ ss << " Inputs: [ " << std::endl;
212+ for (uint64_t i = 0 ; i < num_io. first ; i ++) {
213+ ss << " id: " << i << std::endl;
214+ ss << " shape: " << exec_ctx-> getTensorShape ( std::string ( " input_ " + str (i)). c_str ()) << std::endl ;
215+ ss << " dtype: "
216+ << util::TRTDataTypeToScalarType (exec_ctx-> getEngine (). getTensorDataType ( std::string ( " input_ " + str (i)). c_str ()))
217+ << std::endl;
218+ }
219+ ss << " ]" << std::endl;
220+ ss << " Outputs: [ " << std::endl;
221+ for ( uint64_t o = 0 ; o < num_io. second ; o++) {
222+ ss << " id : " << o << std::endl;
223+ ss << " shape : " << exec_ctx-> getTensorShape ( std::string ( " output_ " + str (o)). c_str ()) << std::endl;
224+ ss << " dtype : "
225+ << util::TRTDataTypeToScalarType (
226+ exec_ctx-> getEngine (). getTensorDataType ( std::string ( " output_ " + str (o)). c_str ()))
227+ << std::endl;
228228 }
229229 ss << " }" << std::endl;
230230 ss << " Device: " << device_info << std::endl;
0 commit comments