@@ -11,8 +11,6 @@ namespace torch_tensorrt {
1111namespace core {
1212namespace runtime {
1313
14- typedef enum { ABI_TARGET_IDX = 0 , NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15-
1614std::string slugify (std::string s) {
1715 std::replace (s.begin (), s.end (), ' .' , ' _' );
1816 return s;
@@ -35,7 +33,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) {
3533 std::string _name = serialized_info[NAME_IDX];
3634 std::string engine_info = serialized_info[ENGINE_IDX];
3735
38- CudaDevice cuda_device = deserialize_device (serialized_info[DEVICE_IDX]);
36+ CudaDevice cuda_device (serialized_info[DEVICE_IDX]);
3937 new (this ) TRTEngine (_name, engine_info, cuda_device);
4038}
4139
@@ -124,43 +122,6 @@ std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
124122 return os;
125123}
126124
127- // TODO: Implement a call method
128- // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
129- // auto input_vec = inputs.vec();
130- // auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
131- //
132- // return c10::List<at::Tensor>(output_vec);
133- // }
134-
135- namespace {
136- static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
137- torch::class_<TRTEngine>(" tensorrt" , " Engine" )
138- .def(torch::init<std::vector<std::string>>())
139- // TODO: .def("__call__", &TRTEngine::Run)
140- // TODO: .def("run", &TRTEngine::Run)
141- .def(" __str__" , &TRTEngine::to_str)
142- .def_pickle(
143- [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
144- // Serialize TensorRT engine
145- auto serialized_trt_engine = self->cuda_engine ->serialize ();
146-
147- // Adding device info related meta data to the serialized file
148- auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
149-
150- std::vector<std::string> serialize_info;
151- serialize_info.resize (ENGINE_IDX + 1 );
152-
153- serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
154- serialize_info[NAME_IDX] = self->name ;
155- serialize_info[DEVICE_IDX] = serialize_device (self->device_info );
156- serialize_info[ENGINE_IDX] = trt_engine;
157- return serialize_info;
158- },
159- [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
160- return c10::make_intrusive<TRTEngine>(std::move (seralized_info));
161- });
162- } // namespace
163-
164125} // namespace runtime
165126} // namespace core
166127} // namespace torch_tensorrt
0 commit comments