 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
+from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


 Mesh = jax.sharding.Mesh
@@ -108,32 +109,6 @@ def __init__(
     # out_shardings=self.get_decode_state_sharding())
     self._lock = threading.RLock()

-  # pylint: disable-next=all
-  def sharding_by_name(self, name):
-
-    # This allows easier way to edit shardings
-    """
-    for key, val in self.env._data.experimental_sharding_axis_override.items():
-      if name.endswith(key):
-        return self.env.sharding_by_axis(val)
-    """
-
-    if "weight_scaler" in name:
-      return self.x_sharding
-    if "tok_embeddings." in name:
-      return self.y_sharding
-    if "attention." in name:
-      if "wo" in name:
-        return self.y_sharding
-      return self.x_sharding
-    if "feed_forward." in name:
-      if "w2" in name:
-        return self.y_sharding
-      return self.x_sharding
-    if "output" in name:
-      return self.x_sharding
-    return self.replicated
-
   # pylint: disable-next=all
   def init_decode_state(
       self,
@@ -561,7 +536,7 @@ def _load_from_safetensors(self, path):
       for key, model_weights in self.pt_model.state_dict().items():
         if key == "freqs_cis":
           continue
-        arr = jax.device_put(f.get_tensor(key), self.sharding_by_name(key))
+        arr = jax.device_put(f.get_tensor(key), self.env.sharding_by_name(key))
         assert tuple(model_weights.shape) == tuple(
             arr.shape
         ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
@@ -587,7 +562,7 @@ def load_params(self) -> Params:
     else:
       jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
       jax_weights = {
-          key: jax.device_put(value, self.sharding_by_name(key))
+          key: jax.device_put(value, self.env.sharding_by_name(key))
           for key, value in jax_weights.items()
       }
       for k, v in jax_weights.items():
@@ -664,6 +639,7 @@ def create_pytorch_engine(
     quantize_weights=False,
     quantize_kv=False,
     max_cache_length=1024,
+    sharding_config=None,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""

@@ -706,42 +682,58 @@ def create_pytorch_engine(
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None

+  env_data = JetEngineEnvironmentData(
+      tokenizer_path=tokenizer_path,
+      checkpoint_path=checkpoint_path,
+      checkpoint_format=checkpoint_format,
+      batch_size=batch_size,
+      max_decode_length=max_decode_length,
+      max_input_sequence_length=context_length,
+      enable_weight_quantization=quantize_weights,
+      enable_kv_quantization=quantize_kv,
+      cache_sequence_length=max_cache_length,
+      bf16_enable=bf16_enable,
+      sharding_config_path=sharding_config,
+  )
+
   if model_name.startswith("llama"):

     args = model_args.get_model_args(
         model_name + "-" + param_size, context_length, batch_size, bf16_enable
     )
     args.device = "meta"
     args.quantize = quantize_weights
-    env_data = JetEngineEnvironmentData(
-        tokenizer_path=tokenizer_path,
-        checkpoint_path=checkpoint_path,
-        checkpoint_format=checkpoint_format,
-        model_type="llama-2-" + param_size,
-        batch_size=batch_size,
-        max_decode_length=max_decode_length,
-        max_input_sequence_length=context_length,
-        enable_weight_quantization=quantize_weights,
-        enable_kv_quantization=quantize_kv,
-        cache_sequence_length=max_cache_length,
-        bf16_enable=bf16_enable,
-        num_layers=args.n_layers,
-        cache_shape=(
-            batch_size,
-            args.n_kv_heads,
-            max_cache_length,
-            args.dim // args.n_heads,
-        ),
+    env_data.cache_shape = (
+        batch_size,
+        args.n_kv_heads,
+        max_cache_length,
+        args.dim // args.n_heads,
     )
+    env_data.model_type = "llama-2-" + param_size
+    env_data.num_layers = args.n_layers
     env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
-
-    num_params_size = 0
-    num_params = 0
-    for _, v in pt_model.state_dict().items():
-      num_params += 1
-      num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
-    print("Number of param Gbytes:", num_params_size / (1 << 30))
-    print("Number of param: ", num_params)
+  elif model_name == "gemma":
+    args = gemma_config.get_model_config(param_size)
+    env_data.cache_shape = (
+        batch_size,
+        args.num_key_value_heads,
+        max_cache_length,
+        args.head_dim,
+    )
+    env_data.model_type = "gemma-" + param_size
+    env_data.num_layers = args.num_hidden_layers
+    env = JetEngineEnvironment(env_data)
+    pt_model = gemma_model.GemmaModel(args, env)
+  else:
+    raise RuntimeError(f"Model with name {model_name} not found")
+
+  num_params_size = 0
+  num_params = 0
+  for _, v in pt_model.state_dict().items():
+    num_params += 1
+    num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
+  print("Number of param Gbytes:", num_params_size / (1 << 30))
+  print("Number of param: ", num_params)

   return PyTorchEngine(pt_model=pt_model, env=env)
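
Usage note: a minimal sketch of calling the refactored factory after this change. It assumes the file under diff is `jetstream_pt/engine.py` (so `create_pytorch_engine` is importable from `jetstream_pt.engine`), that the argument names visible in the function body above (`model_name`, `param_size`, `tokenizer_path`, `checkpoint_path`, `checkpoint_format`, `batch_size`, `context_length`, `bf16_enable`) are parameters of the factory, and that the paths, sizes, and sharding-config filename are placeholders, not values defined by this diff.

# Sketch only: parameter values and paths below are placeholder assumptions.
from jetstream_pt.engine import create_pytorch_engine

engine = create_pytorch_engine(
    model_name="gemma",          # new branch; "llama..." names take the llama branch
    param_size="2b",             # forwarded to gemma_config.get_model_config
    tokenizer_path="/tmp/tokenizer.model",      # placeholder path
    checkpoint_path="/tmp/gemma-ckpt",          # placeholder path
    checkpoint_format="safetensors",
    batch_size=1,
    context_length=1024,
    max_cache_length=1024,
    bf16_enable=True,
    sharding_config="gemma_sharding.yaml",      # new optional argument; placeholder file
)

# load_params() now shards each tensor with env.sharding_by_name(key),
# replacing the engine-level sharding_by_name method removed above.
params = engine.load_params()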