@@ -211,6 +211,18 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
         assert type(compile_spec["workspace_size"]) is int
         info.workspace_size = compile_spec["workspace_size"]
 
+    if "dla_sram_size" in compile_spec:
+        assert type(compile_spec["dla_sram_size"]) is int
+        info.dla_sram_size = compile_spec["dla_sram_size"]
+
+    if "dla_local_dram_size" in compile_spec:
+        assert type(compile_spec["dla_local_dram_size"]) is int
+        info.dla_local_dram_size = compile_spec["dla_local_dram_size"]
+
+    if "dla_global_dram_size" in compile_spec:
+        assert type(compile_spec["dla_global_dram_size"]) is int
+        info.dla_global_dram_size = compile_spec["dla_global_dram_size"]
+
     if "truncate_long_and_double" in compile_spec:
         assert type(compile_spec["truncate_long_and_double"]) is bool
         info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
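For context, a minimal sketch of how the new keys would appear in a compile_spec dictionary handled by _parse_compile_spec. The values shown are the defaults introduced in this change; other keys (for example "inputs") are omitted here for brevity and would normally be present:

    dla_compile_spec = {
        "dla_sram_size": 1048576,           # 1 MiB of fast software-managed DLA SRAM
        "dla_local_dram_size": 1073741824,  # 1 GiB of host RAM for intermediate tensor data
        "dla_global_dram_size": 536870912,  # 512 MiB of host RAM for weights and metadata
    }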
@@ -229,9 +241,11 @@ def TensorRTCompileSpec(inputs=[],
                          refit=False,
                          debug=False,
                          capability=_enums.EngineCapability.default,
-                         num_min_timing_iters=2,
                          num_avg_timing_iters=1,
                          workspace_size=0,
+                         dla_sram_size=1048576,
+                         dla_local_dram_size=1073741824,
+                         dla_global_dram_size=536870912,
                          truncate_long_and_double=False,
                          calibrator=None) -> torch.classes.tensorrt.CompileSpec:
     """Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
@@ -263,7 +277,6 @@ def TensorRTCompileSpec(inputs=[],
         refit (bool): Enable refitting
         debug (bool): Enable debuggable engine
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
-        num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
         workspace_size (int): Maximum size of workspace given to TensorRT
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
@@ -283,9 +296,11 @@ def TensorRTCompileSpec(inputs=[],
283296 "refit" : refit , # enable refit
284297 "debug" : debug , # enable debuggable engine
285298 "capability" : capability , # Restrict kernel selection to safe gpu kernels or safe dla kernels
286- "num_min_timing_iters" : num_min_timing_iters , # Number of minimization timing iterations used to select kernels
287299 "num_avg_timing_iters" : num_avg_timing_iters , # Number of averaging timing iterations used to select kernels
288300 "workspace_size" : workspace_size , # Maximum size of workspace given to TensorRT
301+ "dla_sram_size" : dla_sram_size , # Fast software managed RAM used by DLA to communicate within a layer.
302+ "dla_local_dram_size" : dla_local_dram_size , # Host RAM used by DLA to share intermediate tensor data across operations
303+ "dla_global_dram_size" : dla_global_dram_size , # Host RAM used by DLA to store weights and metadata for execution
289304 "calibrator" : calibrator ,
290305 "truncate_long_and_double" : truncate_long_and_double
291306 }
@@ -331,9 +346,11 @@ def TensorRTCompileSpec(inputs=[],
     backend_spec._set_debug(parsed_spec.debug)
     backend_spec._set_refit(parsed_spec.refit)
     backend_spec._set_capability(int(parsed_spec.capability))
-    backend_spec._set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
     backend_spec._set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
     backend_spec._set_workspace_size(parsed_spec.workspace_size)
+    backend_spec._set_dla_sram_size(parsed_spec.dla_sram_size)
+    backend_spec._set_dla_local_dram_size(parsed_spec.dla_local_dram_size)
+    backend_spec._set_dla_global_dram_size(parsed_spec.dla_global_dram_size)
     backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
     backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
 
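As a usage sketch, not part of this diff: the new keyword arguments can be passed to torch_tensorrt.ts.TensorRTCompileSpec when targeting a DLA device. The module, input shape, device string, and the torch._C._jit_to_backend call below follow the Torch-TensorRT to_backend workflow and are illustrative placeholders, not code from this change:

    import torch
    import torch_tensorrt

    class TinyModel(torch.nn.Module):  # placeholder module, for illustration only
        def forward(self, x):
            return torch.relu(x)

    scripted_model = torch.jit.script(TinyModel().eval())

    spec = {
        "forward": torch_tensorrt.ts.TensorRTCompileSpec(
            inputs=[torch_tensorrt.Input([1, 3, 224, 224])],
            device=torch_tensorrt.Device("dla:0", allow_gpu_fallback=True),
            dla_sram_size=1048576,            # per-layer scratch RAM on the DLA
            dla_local_dram_size=1073741824,   # host RAM for intermediate tensor data
            dla_global_dram_size=536870912,   # host RAM for weights and metadata
        )
    }
    trt_model = torch._C._jit_to_backend("tensorrt", scripted_model, spec)

The values shown are simply the defaults added by this change (1 MiB SRAM, 1 GiB local DRAM, 512 MiB global DRAM).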