File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change 4040flags .DEFINE_bool ("enable_model_warmup" , False , "enable model warmup" )
4141
4242
43-
4443def shard_weights (env , weights , weight_shardings ):
4544 """Shard weights according to weight_shardings"""
4645 sharded = {}
Original file line number Diff line number Diff line change @@ -168,6 +168,10 @@ def _load_weights(directory):
168168 for key in f .keys ():
169169 state_dict [key ] = f .get_tensor (key ).to (torch .bfloat16 )
170170 # Load the state_dict into the model
171+ if not state_dict :
172+ raise AssertionError (
173+ f"Tried to load weights from { directory } , but couldn't find any."
174+ )
171175 return state_dict
172176
173177
@@ -186,7 +190,8 @@ def instantiate_model_from_repo_id(
186190 """Create model instance by hf model id.+"""
187191 model_dir = _hf_dir (repo_id )
188192 if not FLAGS .internal_use_random_weights and (
189- not os .path .exists (model_dir ) or not os .listdir (model_dir )
193+ not os .path .exists (model_dir )
194+ or not glob .glob (os .path .join (model_dir , "*.safetensors" ))
190195 ):
191196 # no weights has been downloaded
192197 _hf_download (repo_id , model_dir , FLAGS .hf_token )
You can’t perform that action at this time.
0 commit comments