@@ -22,13 +22,13 @@
 import torch
 import torch_xla2
 from torch.utils import _pytree as pytree
-from . import helpers
 
 
 from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.llama.generation_original import LlamaOriginal
 from jetstream_pt import environment
+from tests import helpers
 
 
 class LlamaE2ETest(unittest.TestCase):
@@ -93,9 +93,8 @@ def test_jetstream_llama2_seed(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
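Every hunk from here on makes the same change: the per-test torch.set_default_dtype(...) call is removed and the dtype choice moves behind the bf16_enable flag of helpers.make_env_tiny. A minimal sketch of that pattern, assuming the helper now calls torch.set_default_dtype itself. Only the name make_env_tiny and its bf16_enable parameter come from this diff; _build_tiny_env below is a hypothetical stand-in for the env/model construction the diff does not show:

import torch


def _build_tiny_env(dtype):
  # Hypothetical stand-in for the tiny llama env/model-arg
  # construction that is not visible in this diff.
  return object(), object()


def make_env_tiny(bf16_enable=True):
  """Owns the default-dtype decision so individual tests no longer do."""
  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
  # Replaces the torch.set_default_dtype(...) lines removed in the hunks.
  torch.set_default_dtype(torch_dtype)
  env, model_arg = _build_tiny_env(torch_dtype)
  return env, model_arg

Centralizing the call also keeps tests from silently depending on a default dtype leaked by an earlier test.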
@@ -221,7 +220,6 @@ def test_llama_e2e_float32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertEqual(out_tokens, expected_output_tokens)
 
@@ -232,7 +230,6 @@ def test_llama_e2e_bfloat16(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=True)
-    torch.set_default_dtype(torch.bfloat16)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertNotEqual(out_tokens, expected_output_tokens)
 
@@ -242,9 +239,8 @@ def test_llama_e2e_two_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503], axis=-1)
@@ -315,9 +311,8 @@ def test_llama_e2e_four_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503, 11833, 28551], axis=-1)
@@ -387,7 +382,6 @@ def test_llama_with_original_prefill_decode_32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
@@ -458,12 +452,11 @@ def test_llama_with_original_prefill_decode_32(self):
458452
459453 # pylint: disable-next=all
460454 def test_llama_with_original_prefill_decode (self ):
461- """test jetstream llama by comparing original prefill and decode steps with float32 """
455+ """test jetstream llama by comparing original prefill and decode steps with bf16 """
462456 jax .config .update ("jax_platform_name" , "cpu" )
463457 print (f"---------> { jax .devices ()} " )
464458
465- torch .set_default_dtype (torch .bfloat16 )
466- env , model_arg = helpers .make_env_tiny ()
459+ env , model_arg = helpers .make_env_tiny (bf16_enable = True )
467460 # pylint: disable-next=all
468461 tokens = np .arange (10 , dtype = np .int32 )
469462 true_length = tokens .shape [- 1 ]
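One detail the refactor leaves intact: test_llama_e2e_float32 still asserts token-for-token agreement with the original implementation (assertEqual), while test_llama_e2e_bfloat16 asserts divergence (assertNotEqual), presumably because bfloat16 rounding perturbs the logits enough to change greedy token choices; the change only moves where the dtype is selected, not what the tests expect.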