1313# Set the USE_CXX11=1 to use cxx11_abi
1414USE_CXX11 = 0 if not 'USE_CXX11' in os .environ else os .environ ["USE_CXX11" ]
1515
16+ # Set the USE_HOST_DEPS=1 to use host dependencies for tests
17+ USE_HOST_DEPS = 0 if not 'USE_HOST_DEPS' in os .environ else os .environ ["USE_HOST_DEPS" ]
18+
1619SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
1720
1821nox .options .sessions = ["l0_api_tests-" + "{}.{}" .format (sys .version_info .major , sys .version_info .minor )]
@@ -22,15 +25,14 @@ def install_deps(session):
2225 session .install ("-r" , os .path .join (TOP_DIR , "py" , "requirements.txt" ))
2326 session .install ("-r" , os .path .join (TOP_DIR , "tests" , "py" , "requirements.txt" ))
2427
25- def download_models (session , use_host_env = False ):
28+ def download_models (session ):
2629 print ("Downloading test models" )
2730 session .install ("-r" , os .path .join (TOP_DIR , "tests" , "modules" , "requirements.txt" ))
2831 print (TOP_DIR )
2932 session .chdir (os .path .join (TOP_DIR , "tests" , "modules" ))
30- if use_host_env :
33+ if USE_HOST_DEPS :
3134 session .run_always ('python' , 'hub.py' , env = {'PYTHONPATH' : PYT_PATH })
3235 else :
33- session .install ("-r" , os .path .join (TOP_DIR , "py" , "requirements.txt" ))
3436 session .run_always ('python' , 'hub.py' )
3537
3638def install_torch_trt (session ):
@@ -54,9 +56,9 @@ def download_datasets(session):
5456 os .path .join (TOP_DIR , 'tests/accuracy/datasets/data/cidar-10-batches-bin' ),
5557 external = True )
5658
57- def train_model (session , use_host_env = False ):
59+ def train_model (session ):
5860 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
59- if use_host_env :
61+ if USE_HOST_DEPS :
6062 session .run_always ('python' ,
6163 'main.py' ,
6264 '--lr' , '0.01' ,
@@ -83,12 +85,12 @@ def train_model(session, use_host_env=False):
8385 'export_ckpt.py' ,
8486 'vgg16_ckpts/ckpt_epoch25.pth' )
8587
86- def finetune_model (session , use_host_env = False ):
88+ def finetune_model (session ):
8789 # Install pytorch-quantization dependency
8890 session .install ('pytorch-quantization' , '--extra-index-url' , 'https://pypi.ngc.nvidia.com' )
8991 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
9092
91- if use_host_env :
93+ if USE_HOST_DEPS :
9294 session .run_always ('python' ,
9395 'finetune_qat.py' ,
9496 '--lr' , '0.01' ,
@@ -134,25 +136,25 @@ def cleanup(session):
134136 str ('rm -rf ' ) + target ,
135137 external = True )
136138
137- def run_base_tests (session , use_host_env = False ):
139+ def run_base_tests (session ):
138140 print ("Running basic tests" )
139141 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
140142 tests = [
141143 "test_api.py" ,
142144 "test_to_backend_api.py" ,
143145 ]
144146 for test in tests :
145- if use_host_env :
147+ if USE_HOST_DEPS :
146148 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
147149 else :
148150 session .run_always ("python" , test )
149151
150- def run_accuracy_tests (session , use_host_env = False ):
152+ def run_accuracy_tests (session ):
151153 print ("Running accuracy tests" )
152154 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
153155 tests = []
154156 for test in tests :
155- if use_host_env :
157+ if USE_HOST_DEPS :
156158 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
157159 else :
158160 session .run_always ("python" , test )
@@ -170,7 +172,7 @@ def copy_model(session):
170172 os .path .join (TOP_DIR , str ('tests/py/' ) + file_name ),
171173 external = True )
172174
173- def run_int8_accuracy_tests (session , use_host_env = False ):
175+ def run_int8_accuracy_tests (session ):
174176 print ("Running accuracy tests" )
175177 copy_model (session )
176178 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
@@ -180,12 +182,12 @@ def run_int8_accuracy_tests(session, use_host_env=False):
180182 "test_qat_trt_accuracy.py" ,
181183 ]
182184 for test in tests :
183- if use_host_env :
185+ if USE_HOST_DEPS :
184186 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
185187 else :
186188 session .run_always ("python" , test )
187189
188- def run_trt_compatibility_tests (session , use_host_env = False ):
190+ def run_trt_compatibility_tests (session ):
189191 print ("Running TensorRT compatibility tests" )
190192 copy_model (session )
191193 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
@@ -194,151 +196,121 @@ def run_trt_compatibility_tests(session, use_host_env=False):
194196 "test_ptq_trt_calibrator.py" ,
195197 ]
196198 for test in tests :
197- if use_host_env :
199+ if USE_HOST_DEPS :
198200 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
199201 else :
200202 session .run_always ("python" , test )
201203
202- def run_dla_tests (session , use_host_env = False ):
204+ def run_dla_tests (session ):
203205 print ("Running DLA tests" )
204206 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
205207 tests = [
206208 "test_api_dla.py" ,
207209 ]
208210 for test in tests :
209- if use_host_env :
211+ if USE_HOST_DEPS :
210212 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
211213 else :
212214 session .run_always ("python" , test )
213215
214- def run_multi_gpu_tests (session , use_host_env = False ):
216+ def run_multi_gpu_tests (session ):
215217 print ("Running multi GPU tests" )
216218 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
217219 tests = [
218220 "test_multi_gpu.py" ,
219221 ]
220222 for test in tests :
221- if use_host_env :
223+ if USE_HOST_DEPS :
222224 session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
223225 else :
224226 session .run_always ("python" , test )
225227
226- def run_l0_api_tests (session , use_host_env = False ):
227- if not use_host_env :
228+ def run_l0_api_tests (session ):
229+ if not USE_HOST_DEPS :
228230 install_deps (session )
229231 install_torch_trt (session )
230- download_models (session , use_host_env )
231- run_base_tests (session , use_host_env )
232+ download_models (session )
233+ run_base_tests (session )
232234 cleanup (session )
233235
234- def run_l0_dla_tests (session , use_host_env = False ):
235- if not use_host_env :
236+ def run_l0_dla_tests (session ):
237+ if not USE_HOST_DEPS :
236238 install_deps (session )
237239 install_torch_trt (session )
238- download_models (session , use_host_env )
239- run_base_tests (session , use_host_env )
240+ download_models (session )
241+ run_base_tests (session )
240242 cleanup (session )
241243
242- def run_l1_accuracy_tests (session , use_host_env = False ):
243- if not use_host_env :
244+ def run_l1_accuracy_tests (session ):
245+ if not USE_HOST_DEPS :
244246 install_deps (session )
245247 install_torch_trt (session )
246- download_models (session , use_host_env )
248+ download_models (session )
247249 download_datasets (session )
248- train_model (session , use_host_env )
249- run_accuracy_tests (session , use_host_env )
250+ train_model (session )
251+ run_accuracy_tests (session )
250252 cleanup (session )
251253
252- def run_l1_int8_accuracy_tests (session , use_host_env = False ):
253- if not use_host_env :
254+ def run_l1_int8_accuracy_tests (session ):
255+ if not USE_HOST_DEPS :
254256 install_deps (session )
255257 install_torch_trt (session )
256- download_models (session , use_host_env )
258+ download_models (session )
257259 download_datasets (session )
258- train_model (session , use_host_env )
259- finetune_model (session , use_host_env )
260- run_int8_accuracy_tests (session , use_host_env )
260+ train_model (session )
261+ finetune_model (session )
262+ run_int8_accuracy_tests (session )
261263 cleanup (session )
262264
263- def run_l2_trt_compatibility_tests (session , use_host_env = False ):
264- if not use_host_env :
265+ def run_l2_trt_compatibility_tests (session ):
266+ if not USE_HOST_DEPS :
265267 install_deps (session )
266268 install_torch_trt (session )
267- download_models (session , use_host_env )
269+ download_models (session )
268270 download_datasets (session )
269- train_model (session , use_host_env )
270- run_trt_compatibility_tests (session , use_host_env )
271+ train_model (session )
272+ run_trt_compatibility_tests (session )
271273 cleanup (session )
272274
273- def run_l2_multi_gpu_tests (session , use_host_env = False ):
274- if not use_host_env :
275+ def run_l2_multi_gpu_tests (session ):
276+ if not USE_HOST_DEPS :
275277 install_deps (session )
276278 install_torch_trt (session )
277- download_models (session , use_host_env )
278- run_multi_gpu_tests (session , use_host_env )
279+ download_models (session )
280+ run_multi_gpu_tests (session )
279281 cleanup (session )
280282
281283@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
282284def l0_api_tests (session ):
283285 """When a developer needs to check correctness for a PR or something"""
284- run_l0_api_tests (session , use_host_env = False )
285-
286- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
287- def l0_api_tests_host_deps (session ):
288- """When a developer needs to check basic api functionality using host dependencies"""
289- run_l0_api_tests (session , use_host_env = True )
286+ run_l0_api_tests (session )
290287
291288@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
292- def l0_dla_tests_host_deps (session ):
289+ def l0_dla_tests (session ):
293290 """When a developer needs to check basic api functionality using host dependencies"""
294- run_l0_dla_tests (session , use_host_env = True )
291+ run_l0_dla_tests (session )
295292
296293@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
297294def l1_accuracy_tests (session ):
298295 """Checking accuracy performance on various usecases"""
299- run_l1_accuracy_tests (session , use_host_env = False )
300-
301- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
302- def l1_accuracy_tests_host_deps (session ):
303- """Checking accuracy performance on various usecases using host dependencies"""
304- run_l1_accuracy_tests (session , use_host_env = True )
296+ run_l1_accuracy_tests (session )
305297
306298@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
307299def l1_int8_accuracy_tests (session ):
308300 """Checking accuracy performance on various usecases"""
309- run_l1_int8_accuracy_tests (session , use_host_env = False )
310-
311- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
312- def l1_int8_accuracy_tests_host_deps (session ):
313- """Checking accuracy performance on various usecases using host dependencies"""
314- run_l1_int8_accuracy_tests (session , use_host_env = True )
301+ run_l1_int8_accuracy_tests (session )
315302
316303@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
317304def l2_trt_compatibility_tests (session ):
318305 """Makes sure that TensorRT Python and Torch-TensorRT can work together"""
319- run_l2_trt_compatibility_tests (session , use_host_env = False )
320-
321- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
322- def l2_trt_compatibility_tests_host_deps (session ):
323- """Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
324- run_l2_trt_compatibility_tests (session , use_host_env = True )
306+ run_l2_trt_compatibility_tests (session )
325307
326308@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
327309def l2_multi_gpu_tests (session ):
328310 """Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
329- run_l2_multi_gpu_tests (session , use_host_env = False )
330-
331- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
332- def l2_multi_gpu_tests_host_deps (session ):
333- """Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
334- run_l2_multi_gpu_tests (session , use_host_env = True )
311+ run_l2_multi_gpu_tests (session )
335312
336313@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
337314def download_test_models (session ):
338315 """Grab all the models needed for testing"""
339- download_models (session , use_host_env = False )
340-
341- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
342- def download_test_models_host_deps (session ):
343- """Grab all the models needed for testing using host dependencies"""
344- download_models (session , use_host_env = True )
316+ download_models (session )
0 commit comments