File tree Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Expand file tree Collapse file tree 2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -711,7 +711,10 @@ def create_pytorch_engine(
711711 pt_model = None
712712
713713 if not sharding_config :
714- sharding_config = os .path .join ("default_shardings" , model_name + ".yaml" )
714+ sharding_file_name = "llama" if model_name .startswith ("llama" ) else "gemma"
715+ sharding_config = os .path .join (
716+ "default_shardings" , sharding_file_name + ".yaml"
717+ )
715718
716719 env_data = JetEngineEnvironmentData (
717720 tokenizer_path = tokenizer_path ,
Original file line number Diff line number Diff line change @@ -105,10 +105,6 @@ def main(argv: Sequence[str]):
105105 devices = server_lib .get_devices ()
106106 print (f"devices: { devices } " )
107107 sharding_config_path = _SHARDING_CONFIG .value
108- if not sharding_config_path :
109- sharding_config_path = os .path .join (
110- "default_shardings" , _MODEL_NAME .value + ".yaml"
111- )
112108 engine = jetstream_pt .create_pytorch_engine (
113109 devices = devices ,
114110 tokenizer_path = _TOKENIZER_PATH .value ,
You can’t perform that action at this time.
0 commit comments