@@ -53,13 +53,25 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5353 TORCHTRT_CHECK ((cuda_engine.get () != nullptr ), " Unable to deserialize the TensorRT engine" );
5454
5555 exec_ctx = make_trt (cuda_engine->createExecutionContext ());
56+ TORCHTRT_CHECK ((exec_ctx.get () != nullptr ), " Unable to create TensorRT execution context" );
5657
5758 uint64_t inputs = 0 ;
5859 uint64_t outputs = 0 ;
5960
6061 for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
6162 std::string bind_name = cuda_engine->getBindingName (x);
62- std::string idx_s = bind_name.substr (bind_name.find (" _" ) + 1 );
63+ LOG_DEBUG (" Binding name: " << bind_name);
64+ auto delim = bind_name.find (" ." );
65+ if (delim == std::string::npos) {
66+ delim = bind_name.find (" _" );
67+ TORCHTRT_CHECK (
68+ delim != std::string::npos,
69+ " Unable to determine binding index for input "
70+ << bind_name
71+ << " \n Ensure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions" );
72+ }
73+
74+ std::string idx_s = bind_name.substr (delim + 1 );
6375 uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
6476
6577 if (cuda_engine->bindingIsInput (x)) {
@@ -71,6 +83,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7183 }
7284 }
7385 num_io = std::make_pair (inputs, outputs);
86+
87+ LOG_DEBUG (*this );
7488}
7589
7690TRTEngine& TRTEngine::operator =(const TRTEngine& other) {
@@ -82,6 +96,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
8296 return (*this );
8397}
8498
99+ std::string TRTEngine::to_str () const {
100+ std::stringstream ss;
101+ ss << " Torch-TensorRT TensorRT Engine:" << std::endl;
102+ ss << " Name: " << name << std::endl;
103+ ss << " Inputs: [" << std::endl;
104+ for (uint64_t i = 0 ; i < num_io.first ; i++) {
105+ ss << " id: " << i << std::endl;
106+ ss << " shape: " << exec_ctx->getBindingDimensions (i) << std::endl;
107+ ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getBindingDataType (i)) << std::endl;
108+ }
109+ ss << " ]" << std::endl;
110+ ss << " Outputs: [" << std::endl;
111+ for (uint64_t o = 0 ; o < num_io.second ; o++) {
112+ ss << " id: " << o << std::endl;
113+ ss << " shape: " << exec_ctx->getBindingDimensions (o) << std::endl;
114+ ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getBindingDataType (o)) << std::endl;
115+ }
116+ ss << " ]" << std::endl;
117+ ss << " Device: " << device_info << std::endl;
118+
119+ return ss.str ();
120+ }
121+
122+ std::ostream& operator <<(std::ostream& os, const TRTEngine& engine) {
123+ os << engine.to_str ();
124+ return os;
125+ }
126+
85127// TODO: Implement a call method
86128// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
87129// auto input_vec = inputs.vec();
@@ -96,6 +138,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
96138 .def(torch::init<std::vector<std::string>>())
97139 // TODO: .def("__call__", &TRTEngine::Run)
98140 // TODO: .def("run", &TRTEngine::Run)
141+ .def(" __str__" , &TRTEngine::to_str)
99142 .def_pickle(
100143 [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
101144 // Serialize TensorRT engine
0 commit comments