11import torch
22import timm
33import pytest
4+ import unittest
45
56import torch_tensorrt as torchtrt
67import torchvision .models as models
1213 cosine_similarity ,
1314)
1415
16+ assertions = unittest .TestCase ()
17+
1518
1619@pytest .mark .unit
1720def test_resnet18 (ir ):
@@ -31,9 +34,9 @@ def test_resnet18(ir):
3134
3235 trt_mod = torchtrt .compile (model , ** compile_spec )
3336 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
34- assert (
37+ assertions . assertTrue (
3538 cos_sim > COSINE_THRESHOLD ,
36- f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
39+ msg = f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
3740 )
3841
3942 # Clean up model env
@@ -61,9 +64,9 @@ def test_mobilenet_v2(ir):
6164
6265 trt_mod = torchtrt .compile (model , ** compile_spec )
6366 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
64- assert (
67+ assertions . assertTrue (
6568 cos_sim > COSINE_THRESHOLD ,
66- f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
69+ msg = f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
6770 )
6871
6972 # Clean up model env
@@ -91,9 +94,9 @@ def test_efficientnet_b0(ir):
9194
9295 trt_mod = torchtrt .compile (model , ** compile_spec )
9396 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
94- assert (
97+ assertions . assertTrue (
9598 cos_sim > COSINE_THRESHOLD ,
96- f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
99+ msg = f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
97100 )
98101
99102 # Clean up model env
@@ -134,9 +137,9 @@ def test_bert_base_uncased(ir):
134137 for key in model_outputs .keys ():
135138 out , trt_out = model_outputs [key ], trt_model_outputs [key ]
136139 cos_sim = cosine_similarity (out , trt_out )
137- assert (
140+ assertions . assertTrue (
138141 cos_sim > COSINE_THRESHOLD ,
139- f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
142+ msg = f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
140143 )
141144
142145 # Clean up model env
@@ -164,9 +167,9 @@ def test_resnet18_half(ir):
164167
165168 trt_mod = torchtrt .compile (model , ** compile_spec )
166169 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
167- assert (
170+ assertions . assertTrue (
168171 cos_sim > COSINE_THRESHOLD ,
169- f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
172+ msg = f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
170173 )
171174
172175 # Clean up model env
0 commit comments