@@ -194,6 +194,34 @@ def test_compile(self):
194194 msg = f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
195195 )
196196
197+ def test_compile_full_compilation (self ):
198+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
199+ self .model = (
200+ torch .jit .load (MODULE_DIR + "/tuple_input_output_scripted.jit.pt" )
201+ .eval ()
202+ .to ("cuda" )
203+ )
204+
205+ compile_spec = {
206+ "input_signature" : (
207+ (torchtrt .Input (self .input .shape ), torchtrt .Input (self .input .shape )),
208+ ),
209+ "device" : torchtrt .Device ("gpu:0" ),
210+ "enabled_precisions" : {torch .float },
211+ "min_block_size" : 1 ,
212+ "require_full_compilation" : True ,
213+ }
214+
215+ trt_mod = torchtrt .ts .compile (self .model , ** compile_spec )
216+ trt_out = trt_mod ((self .input , self .input ))
217+ pyt_out = self .model ((self .input , self .input ))
218+ for (t , p ) in zip (trt_out , pyt_out ):
219+ cos_sim = cosine_similarity (t , p )
220+ self .assertTrue (
221+ cos_sim > COSINE_THRESHOLD ,
222+ msg = f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
223+ )
224+
197225
198226class TestListInputOutput (unittest .TestCase ):
199227 def test_compile (self ):
@@ -225,6 +253,36 @@ def test_compile(self):
225253 msg = f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
226254 )
227255
256+ def test_compile_full_compilation (self ):
257+
258+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
259+ self .model = (
260+ torch .jit .load (MODULE_DIR + "/list_input_output_scripted.jit.pt" )
261+ .eval ()
262+ .to ("cuda" )
263+ )
264+
265+ compile_spec = {
266+ "input_signature" : (
267+ [torchtrt .Input (self .input .shape ), torchtrt .Input (self .input .shape )],
268+ ),
269+ "device" : torchtrt .Device ("gpu:0" ),
270+ "enabled_precisions" : {torch .float },
271+ "min_block_size" : 1 ,
272+ "require_full_compilation" : True ,
273+ }
274+
275+ trt_mod = torchtrt .ts .compile (self .model , ** compile_spec )
276+ trt_out = trt_mod ((self .input , self .input ))
277+ pyt_out = self .model ((self .input , self .input ))
278+
279+ for (t , p ) in zip (trt_out , pyt_out ):
280+ cos_sim = cosine_similarity (t , p )
281+ self .assertTrue (
282+ cos_sim > COSINE_THRESHOLD ,
283+ msg = f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
284+ )
285+
228286
229287class TestListInputTupleOutput (unittest .TestCase ):
230288 def test_compile (self ):
@@ -255,6 +313,35 @@ def test_compile(self):
255313 msg = f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
256314 )
257315
316+ def test_compile_full_compilation (self ):
317+
318+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
319+ self .model = (
320+ torch .jit .load (MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt" )
321+ .eval ()
322+ .to ("cuda" )
323+ )
324+
325+ compile_spec = {
326+ "input_signature" : (
327+ [torchtrt .Input (self .input .shape ), torchtrt .Input (self .input .shape )],
328+ ),
329+ "device" : torchtrt .Device ("gpu:0" ),
330+ "enabled_precisions" : {torch .float },
331+ "min_block_size" : 1 ,
332+ "require_full_compilation" : True ,
333+ }
334+
335+ trt_mod = torchtrt .ts .compile (self .model , ** compile_spec )
336+ trt_out = trt_mod ((self .input , self .input ))
337+ pyt_out = self .model ((self .input , self .input ))
338+ for (t , p ) in zip (trt_out , pyt_out ):
339+ cos_sim = cosine_similarity (t , p )
340+ self .assertTrue (
341+ cos_sim > COSINE_THRESHOLD ,
342+ msg = f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
343+ )
344+
258345
259346if __name__ == "__main__" :
260347 unittest .main ()
0 commit comments