Skip to content

Commit 99e6bd1

Browse files
authored
Merge pull request #148 from kcirred/prefix_rewrite
Prefix rewrite
2 parents abe35d3 + fcf950f commit 99e6bd1

File tree

5 files changed

+23
-24
lines changed

5 files changed

+23
-24
lines changed

aiu_fms_testing_utils/testing/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def format_kwargs_to_string(**kwargs):
1717
# only append if formatted_value exists
1818
if formatted_value:
1919
# Keep previous convention of variable names with `-` instead of `_`
20-
formatted_pairs.append(f"{key.replace('_', '-')}-{formatted_value}")
20+
formatted_pairs.append(
21+
f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}"
22+
)
2123

2224
return "_".join(formatted_pairs)

aiu_fms_testing_utils/testing/validation.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,6 @@ def __len__(self):
128128

129129

130130
def get_default_validation_prefix(
131-
model_id: str,
132-
max_new_tokens: int,
133-
batch_size: int,
134-
seq_length: int,
135-
dtype: str,
136-
attn_type: str,
137-
aftu_version: str,
138131
**kwargs,
139132
):
140133
"""
@@ -150,12 +143,12 @@ def get_default_validation_prefix(
150143
Returns:
151144
str: A hashed prefix that will be prepended to the file name
152145
"""
146+
aftu_version = kwargs.pop(
147+
"aftu_version", ".".join([str(_) for _ in version_tuple[:3]])
148+
)
153149
kwargs_str = format_kwargs_to_string(**kwargs)
154150

155-
if kwargs_str == "":
156-
filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
157-
else:
158-
filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}"
151+
filename = f"{kwargs_str}"
159152
hash_object = hashlib.sha256(filename.encode("utf-8"))
160153
hex_digest = hash_object.hexdigest()
161154
return f"{hex_digest}_{aftu_version}"
@@ -435,7 +428,7 @@ def get_validation_info_path(
435428

436429
sample_key = kwargs.get("sample_key", None)
437430

438-
validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
431+
validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
439432
full_path = os.path.join(validation_info_dir, validation_file_name)
440433
return full_path
441434

scripts/generate_layers_metrics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
473473
cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output)
474474

475475
prefix = get_default_validation_prefix(
476-
model_path, max_new_token, batch_size, seq_length, "float16"
476+
model_id=model_path,
477+
max_new_tokens=max_new_token,
478+
batch_size=batch_size,
479+
seq_length=seq_length,
480+
dtype="float16",
477481
)
478482
layer_name = str(layer_key).replace("[", "").replace("]", "")
479483

scripts/generate_metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@
134134

135135
# This follows the same naming pattern as test_shapes, so the outputs can be saved and re-used for quicker shape testing.
136136
prefix = get_default_validation_prefix(
137-
args.variant,
138-
args.max_new_tokens,
139-
args.batch_size,
140-
args.min_pad_length,
141-
args.default_dtype,
137+
model_id=args.variant,
138+
max_new_tokens=args.max_new_tokens,
139+
batch_size=args.batch_size,
140+
seq_length=args.min_pad_length,
141+
dtype=args.default_dtype,
142142
)
143143
if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
144144
print("skipping metric generation as it has already been done")

tests/testing/test_validation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook):
8080

8181

8282
def test_get_validation_info_path(tmp_path):
83-
check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa"
83+
check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64"
8484
hash_object = hashlib.sha256(check_pathname.encode("utf-8"))
8585
hex_digest = hash_object.hexdigest()
8686

@@ -91,7 +91,7 @@ def test_get_validation_info_path(tmp_path):
9191
== f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out"
9292
)
9393

94-
check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa"
94+
check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64"
9595
hash_object = hashlib.sha256(check_pathname.encode("utf-8"))
9696
hex_digest = hash_object.hexdigest()
9797

@@ -295,15 +295,15 @@ def test_get_default_validation_prefix(
295295

296296
sample_key = None
297297
# get_default_validation_prefix with sample_key set to None
298-
check_prefix_sample_key_none = f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
298+
check_prefix_sample_key_none = f"attn-type-{attn_type}_batch-size-{batch_size}_dtype-{dtype}_max-new-tokens-{max_new_tokens}_model-id-{model_variant}_seq-length-{seq_length}"
299299
hash_object = hashlib.sha256(check_prefix_sample_key_none.encode("utf-8"))
300300
hex_digest = hash_object.hexdigest()
301-
prefix_sample_key_none = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
301+
prefix_sample_key_none = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
302302

303303
assert prefix_sample_key_none == f"{hex_digest}_1.2.3.cpu_validation_info.0.out"
304304

305305
# get_default_validation_prefix with no kwargs using legacy case
306-
legacy_prefix = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
306+
legacy_prefix = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
307307
assert prefix_sample_key_none == legacy_prefix
308308

309309
# retrieve a sample_key with return_key is True

0 commit comments

Comments
 (0)