55
66# Use system installed Python packages
77PYT_PATH = '/opt/conda/lib/python3.8/site-packages' if not 'PYT_PATH' in os .environ else os .environ ["PYT_PATH" ]
8+ print (f"Using python path { PYT_PATH } " )
89
910# Set the root directory to the directory of the noxfile unless the user wants to
1011# TOP_DIR
1112TOP_DIR = os .path .dirname (os .path .realpath (__file__ )) if not 'TOP_DIR' in os .environ else os .environ ["TOP_DIR" ]
13+ print (f"Test root directory { TOP_DIR } " )
1214
1315# Set the USE_CXX11=1 to use cxx11_abi
1416USE_CXX11 = 0 if not 'USE_CXX11' in os .environ else os .environ ["USE_CXX11" ]
17+ if USE_CXX11 :
18+ print ("Using cxx11 abi" )
1519
1620# Set the USE_HOST_DEPS=1 to use host dependencies for tests
1721USE_HOST_DEPS = 0 if not 'USE_HOST_DEPS' in os .environ else os .environ ["USE_HOST_DEPS" ]
22+ if USE_HOST_DEPS :
23+ print ("Using dependencies from host python" )
1824
1925SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
2026
@@ -58,6 +64,12 @@ def download_datasets(session):
5864
5965def train_model (session ):
6066 session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
67+ session .install ("-r" , "requirements.txt" )
68+ if os .path .exists ('vgg16_ckpts/ckpt_epoch25.pth' ):
69+ session .run_always ('python' ,
70+ 'export_ckpt.py' ,
71+ 'vgg16_ckpts/ckpt_epoch25.pth' )
72+ return
6173 if USE_HOST_DEPS :
6274 session .run_always ('python' ,
6375 'main.py' ,
@@ -140,14 +152,14 @@ def run_base_tests(session):
140152 print ("Running basic tests" )
141153 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
142154 tests = [
143- "test_api.py " ,
144- "test_to_backend_api.py" ,
155+ "api " ,
156+ "integrations/ test_to_backend_api.py" ,
145157 ]
146158 for test in tests :
147159 if USE_HOST_DEPS :
148- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
160+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
149161 else :
150- session .run_always ("python " , test )
162+ session .run_always ("pytest " , test )
151163
152164def run_accuracy_tests (session ):
153165 print ("Running accuracy tests" )
@@ -169,23 +181,23 @@ def copy_model(session):
169181 session .run_always ('cp' ,
170182 '-rpf' ,
171183 os .path .join (TOP_DIR , src_file ),
172- os .path .join (TOP_DIR , str ('tests/py /' ) + file_name ),
184+ os .path .join (TOP_DIR , str ('tests/modules /' ) + file_name ),
173185 external = True )
174186
175187def run_int8_accuracy_tests (session ):
176188 print ("Running accuracy tests" )
177189 copy_model (session )
178190 session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
179191 tests = [
180- "test_ptq_dataloader_calibrator .py" ,
181- "test_ptq_to_backend .py" ,
182- "test_qat_trt_accuracy.py " ,
192+ "ptq/test_ptq_to_backend .py" ,
193+ "ptq/test_ptq_dataloader_calibrator .py" ,
194+ "qat/ " ,
183195 ]
184196 for test in tests :
185197 if USE_HOST_DEPS :
186- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
198+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
187199 else :
188- session .run_always ("python " , test )
200+ session .run_always ("pytest " , test )
189201
190202def run_trt_compatibility_tests (session ):
191203 print ("Running TensorRT compatibility tests" )
@@ -197,9 +209,9 @@ def run_trt_compatibility_tests(session):
197209 ]
198210 for test in tests :
199211 if USE_HOST_DEPS :
200- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
212+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
201213 else :
202- session .run_always ("python " , test )
214+ session .run_always ("pytest " , test )
203215
204216def run_dla_tests (session ):
205217 print ("Running DLA tests" )
@@ -209,9 +221,9 @@ def run_dla_tests(session):
209221 ]
210222 for test in tests :
211223 if USE_HOST_DEPS :
212- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
224+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
213225 else :
214- session .run_always ("python " , test )
226+ session .run_always ("pytest " , test )
215227
216228def run_multi_gpu_tests (session ):
217229 print ("Running multi GPU tests" )
@@ -221,9 +233,9 @@ def run_multi_gpu_tests(session):
221233 ]
222234 for test in tests :
223235 if USE_HOST_DEPS :
224- session .run_always ('python ' , test , env = {'PYTHONPATH' : PYT_PATH })
236+ session .run_always ('pytest ' , test , env = {'PYTHONPATH' : PYT_PATH })
225237 else :
226- session .run_always ("python " , test )
238+ session .run_always ("pytest " , test )
227239
228240def run_l0_api_tests (session ):
229241 if not USE_HOST_DEPS :
@@ -245,7 +257,6 @@ def run_l1_accuracy_tests(session):
245257 if not USE_HOST_DEPS :
246258 install_deps (session )
247259 install_torch_trt (session )
248- download_models (session )
249260 download_datasets (session )
250261 train_model (session )
251262 run_accuracy_tests (session )
@@ -255,7 +266,6 @@ def run_l1_int8_accuracy_tests(session):
255266 if not USE_HOST_DEPS :
256267 install_deps (session )
257268 install_torch_trt (session )
258- download_models (session )
259269 download_datasets (session )
260270 train_model (session )
261271 finetune_model (session )
@@ -313,4 +323,8 @@ def l2_multi_gpu_tests(session):
313323@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
314324def download_test_models (session ):
315325 """Grab all the models needed for testing"""
326+ try :
327+ import torch
328+ except ModuleNotFoundError :
329+ install_deps (session )
316330 download_models (session )
0 commit comments