Skip to content

Commit 34116ca

Browse files
committed
Merge branch 'main' into docs
2 parents 45d6d4d + bb03e87 commit 34116ca

7 files changed

+13
-9
lines changed

aiu_fms_testing_utils/testing/validation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def save(self, save_dir_path: str):
7676
:param save_dir_path: the path to save to
7777
"""
7878
dprint(f"saving validation info to {save_dir_path}")
79-
if not os.path.exists(save_dir_path):
80-
os.makedirs(save_dir_path)
79+
os.makedirs(save_dir_path, exist_ok=True)
8180

8281
for sentence_i, sentence in enumerate(self._validation_info_list):
8382
file_path = os.path.join(save_dir_path, f"{sentence_i}.pt")
@@ -193,9 +192,11 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
193192
max_seq_len = model.config.max_expected_seq_len
194193

195194
# Add only_last_token optimization
196-
extra_generation_kwargs = {**padding_kwargs, "attn_algorithm": attn_algorithm}
195+
extra_generation_kwargs = {**padding_kwargs}
197196
if only_last_token:
198197
extra_generation_kwargs["only_last_token"] = only_last_token
198+
if attn_algorithm is not None:
199+
extra_generation_kwargs["attn_algorithm"] = attn_algorithm
199200

200201
result = generate(
201202
model,

tests/models/test_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
327327

328328
distributed_kwargs = {}
329329
if USE_DISTRIBUTED:
330-
distributed_kwargs["distr_param"] = "tp"
330+
distributed_kwargs["distributed_strategy"] = "tp"
331331
distributed_kwargs["group"] = dist.group.WORLD
332332

333333
get_model_kwargs = {}

tests/models/test_model_expectations.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@
1919
model_dir = os.environ.get("FMS_TESTING_MODEL_DIR", "/tmp/models")
2020
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
2121
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
22+
GRANITE_GUARDIAN_3p1_8B = "ibm-granite/granite-guardian-3.1-8b"
2223
ROBERTA_SQUAD_v2 = "deepset/roberta-base-squad2"
23-
torch.manual_seed(42)
2424

25-
micro_models = {LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT}
25+
micro_models = {LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_GUARDIAN_3p1_8B}
2626

2727

2828
class AIUModelFixtureMixin(ModelFixtureMixin):
2929
@pytest.fixture(scope="class", autouse=True)
3030
def uninitialized_model(self, model_id):
31+
torch.manual_seed(42)
3132
if model_id in micro_models:
3233
get_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
3334
else:
@@ -52,7 +53,7 @@ def model(self, uninitialized_model):
5253
return uninitialized_model
5354

5455

55-
decoder_models = [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT]
56+
decoder_models = [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_GUARDIAN_3p1_8B]
5657

5758

5859
class TestAIUDecoderModels(
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
12.65625,12.65625,12.59375,12.625,8.78125,37.96875,14.5625,16.90625,5.0,13.4375,14.71875,20.6875,11.03125,26.15625,39.53125,8.1875,7.0625,35.03125,6.28125,5.1875,13.25,5.15625,12.96875,8.65625,6.96875,19.375,7.21875,15.78125,14.53125,29.40625,8.5625,9.0625,8.5,1.375,16.21875,18.90625,20.34375,13.8125,8.53125,7.75,16.375,17.96875,7.1875,10.65625,11.625,56.15625,11.96875,5.3125,12.21875,4.1875,7.0625,0.0,10.34375,17.3125,32.84375,40.65625,40.78125,12.84375,8.4375,10.53125,8.5,9.125,8.625,14.34375
1+
10.71875,10.71875,10.65625,10.6875,6.84375,36.53125,12.1875,13.65625,0.09375,10.03125,15.71875,4.28125,11.3125,15.125,32.15625,13.59375,5.5625,13.125,3.375,4.625,19.90625,10.65625,0.3125,11.0,4.6875,9.8125,9.3125,9.65625,26.65625,12.59375,9.53125,41.90625,0.0625,0.9375,10.625,24.25,15.46875,10.96875,8.1875,7.25,3.15625,10.75,15.625,17.5,3.96875,32.78125,2.5625,1.46875,7.84375,0.0,19.0625,2.3125,12.5,6.78125,20.4375,10.34375,39.96875,21.625,3.8125,7.84375,9.96875,3.1875,7.75,5.15625
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.010009765625,0.010009765625,0.009521484375,0.009521484375,0.0234375,0.024658203125,0.014892578125,0.032958984375,0.01611328125,0.00732421875,0.054443359375,0.011474609375,0.013427734375,0.03173828125,0.016357421875,0.015869140625,0.022705078125,0.0205078125,0.025390625,0.017333984375,0.017333984375,0.005615234375,0.012451171875,0.002685546875,0.04296875,0.011962890625,0.017822265625,0.032470703125,0.00244140625,0.025390625,0.013671875,0.07177734375,0.035888671875,0.026611328125,0.0263671875,0.021240234375,0.0263671875,0.007080078125,0.02978515625,0.033203125,0.028564453125,0.031982421875,0.01318359375,0.0263671875,0.0166015625,0.00927734375,0.04345703125,0.028564453125,0.01416015625,0.041748046875,0.0185546875,0.01611328125,0.0166015625,0.0341796875,0.01220703125,0.0,0.01611328125,0.017578125,0.0146484375,0.031005859375,0.021484375,0.02978515625,0.006103515625,0.032470703125
1+
0.041748046875,0.041748046875,0.04150390625,0.041748046875,0.03662109375,0.037109375,0.020263671875,0.01904296875,0.0029296875,0.0341796875,0.04248046875,0.009033203125,0.013427734375,0.03857421875,0.0380859375,0.0498046875,0.033203125,0.004638671875,0.033203125,0.009765625,0.0107421875,0.015380859375,0.05078125,0.026123046875,0.017578125,0.0302734375,0.024658203125,0.011962890625,0.0361328125,0.068359375,0.03662109375,0.0322265625,0.03759765625,0.023681640625,0.051025390625,0.00830078125,0.032958984375,0.017578125,0.043212890625,0.022216796875,0.01513671875,0.036376953125,0.037841796875,0.015625,0.030517578125,0.060302734375,0.0283203125,0.0361328125,0.0,0.01953125,0.013916015625,0.026611328125,0.02197265625,0.004638671875,0.021484375,0.02001953125,0.010986328125,0.018798828125,0.020263671875,0.031982421875,0.03515625,0.014404296875,0.005615234375,0.037109375
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.041748046875,0.041748046875,0.04150390625,0.041748046875,0.03662109375,0.037109375,0.020263671875,0.01904296875,0.0029296875,0.0341796875,0.04248046875,0.009033203125,0.013427734375,0.03857421875,0.0380859375,0.0498046875,0.033203125,0.004638671875,0.033203125,0.009765625,0.0107421875,0.015380859375,0.05078125,0.026123046875,0.017578125,0.0302734375,0.024658203125,0.011962890625,0.0361328125,0.068359375,0.03662109375,0.0322265625,0.03759765625,0.023681640625,0.051025390625,0.00830078125,0.032958984375,0.017578125,0.043212890625,0.022216796875,0.01513671875,0.036376953125,0.037841796875,0.015625,0.030517578125,0.060302734375,0.0283203125,0.0361328125,0.0,0.01953125,0.013916015625,0.026611328125,0.02197265625,0.004638671875,0.021484375,0.02001953125,0.010986328125,0.018798828125,0.020263671875,0.031982421875,0.03515625,0.014404296875,0.005615234375,0.037109375
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.key.weight,base_model.layers.0.attn.in_proj.query.weight,base_model.layers.0.attn.in_proj.value.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w1.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.key.weight,base_model.layers.1.attn.in_proj.query.weight,base_model.layers.1.attn.in_proj.value.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w1.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg.weight,base_model.layers.1.ln.weight,base_model.layers.2.attn.dense.weight,base_model.layers.2.attn.in_proj.key.weight,base_model.layers.2.attn.in_proj.query.weight,base_model.layers.2.attn.in_proj.value.weight,base_model.layers.2.ff_ln.weight,base_model.layers.2.ff_sub_layer.w1.weight,base_model.layers.2.ff_sub_layer.w2.weight,base_model.layers.2.ff_sub_layer.wg.weight,base_model.layers.2.ln.weight,head.weight

0 commit comments

Comments
 (0)