Skip to content

Commit ec66526

Browse files
authored
Raise error if weights are not loaded (#206)
1 parent bb174b6 commit ec66526

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

jetstream_pt/cli.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
40 40   flags.DEFINE_bool("enable_model_warmup", False, "enable model warmup")
41 41
42 42
43    -
44 43   def shard_weights(env, weights, weight_shardings):
45 44   """Shard weights according to weight_shardings"""
46 45   sharded = {}

jetstream_pt/fetch_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ def _load_weights(directory):
168 168   for key in f.keys():
169 169   state_dict[key] = f.get_tensor(key).to(torch.bfloat16)
170 170   # Load the state_dict into the model
    171 + if not state_dict:
    172 + raise AssertionError(
    173 + f"Tried to load weights from {directory}, but couldn't find any."
    174 + )
171 175   return state_dict
172 176
173 177
@@ -186,7 +190,8 @@ def instantiate_model_from_repo_id(
186 190   """Create model instance by hf model id.+"""
187 191   model_dir = _hf_dir(repo_id)
188 192   if not FLAGS.internal_use_random_weights and (
189     - not os.path.exists(model_dir) or not os.listdir(model_dir)
    193 + not os.path.exists(model_dir)
    194 + or not glob.glob(os.path.join(model_dir, "*.safetensors"))
190 195   ):
191 196   # no weights has been downloaded
192 197   _hf_download(repo_id, model_dir, FLAGS.hf_token)

0 commit comments

Comments
 (0)