
Commit b9068c6

Merge pull request #9 from foundation-model-stack/validation_criteria
Robust level 1 validation testing
2 parents: df7aca2 + aa25fc6 · commit b9068c6

4 files changed: +383 -155 lines changed


aiu_fms_testing_utils/testing/validation.py

Lines changed: 6 additions & 5 deletions
@@ -286,14 +286,15 @@ def capture_level_1_metrics(reference_logits_per_sentence, test_logits_per_sente
 
     return loss_metrics
 
-def filter_failed_level_1_cases(level_1_loss_metrics, fail_f):
+def filter_failed_level_1_cases(level_1_loss_metrics, fail_f, print_failed=False):
     failed_cases = []
     for (sentence_idx, token_idx, metrics_value) in level_1_loss_metrics:
         if fail_f(metrics_value):
             failed_cases.append((sentence_idx, token_idx, metrics_value))
-            print(
-                f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
-            )
+            if print_failed:
+                dprint(
+                    f"In sentence {sentence_idx+1}, the metric for token {token_idx} is {metrics_value}"
+                )
     return failed_cases
 
 
@@ -304,4 +305,4 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
 
         aiu_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(aiu_token))
         validation_str = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(validation_token))
-        print(f"In sentence {sentence_index+1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} != CPU val={validation_str}")
+        print(f"In sentence {sentence_index+1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}")

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 63 additions & 1 deletion
@@ -1,8 +1,14 @@
 import torch
 import torch.nn as nn
 import time
+from fms.utils.tokenizers import BaseTokenizer
 from fms.utils.generation import generate
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from typing import Optional, List, Tuple
+import os
+import requests
+import json
+import random
 
 def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, **padding_kwargs):
     from torch_sendnn import torch_sendnn
@@ -25,4 +31,60 @@ def ids_for_prompt(prompt, tokenizer):
     if tokenizer.bos_token_id != tokenizer.eos_token_id:
         ids = [tokenizer.bos_token_id] + ids
     ids = torch.tensor(ids, dtype=torch.long, device="cpu")
-    return ids
+    return ids
+
+def __download_file(url, filename):
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        with open(filename, 'wb') as file:
+            for chunk in response.iter_content(chunk_size=8192):
+                file.write(chunk)
+        print(f"Successfully downloaded {filename}")
+
+    except requests.exceptions.RequestException as e:
+        print(f"An error occurred: {e}")
+
+def sample_sharegpt_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
+    seed: Optional[int] = None
+) -> List[Tuple[str, int]]:
+    if not os.path.exists(dataset_path):
+        print("downloading share-gpt dataset as it does not exist")
+        __download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path)
+
+    # Load the dataset.
+    with open(dataset_path, encoding='utf-8') as f:
+        dataset = json.load(f)
+    # Filter out the conversations with less than 2 turns.
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+    # Only keep the first two turns of each conversation.
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
+
+    # Shuffle the dataset.
+    if seed is not None:
+        random.Random(seed).shuffle(dataset)
+
+    # Filter out sequences that are too long or too short
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = ids_for_prompt(prompt, tokenizer)
+
+        prompt_len = len(prompt_token_ids)
+        if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
+            # Prune too short or too long sequences.
+            continue
+        filtered_dataset.append((prompt, prompt_len))
+
+    return filtered_dataset
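
A hedged usage sketch of the new sampler (the paths below are placeholders, and `tokenizers.get_tokenizer` is assumed to be the usual fms tokenizer entry point): on first use it downloads the ShareGPT JSON to `dataset_path`, then returns up to `num_requests` `(prompt, prompt_len)` pairs whose tokenized length falls between `prompt_length_min` and `prompt_length_max`, shuffled deterministically when a seed is given.

```python
from fms.utils import tokenizers  # assumed tokenizer factory, as used in the scripts
from aiu_fms_testing_utils.utils import sample_sharegpt_requests

tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")  # placeholder path
pairs = sample_sharegpt_requests(
    dataset_path="/tmp/share_gpt.json",  # downloaded automatically if missing
    num_requests=4,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,  # deterministic shuffle; None keeps the original dataset order
)
print([prompt_len for _, prompt_len in pairs])
```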

scripts/generate_metrics.py

Lines changed: 108 additions & 100 deletions
@@ -1,4 +1,5 @@
 import argparse
+import ast
 import json
 import os
 import random
@@ -8,7 +9,7 @@
 
 from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, print_failed_cases, \
     validate_level_0, GoldenTokenHook, top_k_loss_calculator
-from aiu_fms_testing_utils.utils import ids_for_prompt
+from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
 from fms.models import get_model
 from fms.utils import tokenizers
 from fms.utils.generation import pad_input_ids
@@ -84,10 +85,30 @@
     help="top k values per token to generate loss on",
     default=20
 )
+parser.add_argument(
+    "--num_test_tokens_per_sequence",
+    type=int,
+    help="number of tokens in test. For instance, if max_new_tokens=128 and num_test_tokens_per_sequence=256, this means we will generate data over 2 sample prompts. If not set, will be set to max_new_tokens",
+    default=None
+)
+parser.add_argument(
+    "--extra_get_model_kwargs",
+    nargs='*',
+    default={},
+    help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..."
+)
 args = parser.parse_args()
 
+extra_get_model_kwargs = {}
+for a in args.extra_get_model_kwargs:
+    a_split = a.split("=")
+    try:
+        extra_get_model_kwargs[a_split[0]] = ast.literal_eval(a_split[1])
+    except ValueError:
+        extra_get_model_kwargs[a_split[0]] = a_split[1]
 
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
+# this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing.
+prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length-{args.min_pad_length}_dtype-{args.default_dtype}"
 if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
     print("skipping metric generation as it has already been done")
     exit(0)
@@ -115,11 +136,11 @@
     model_path=args.model_path,
     device_type="cuda",
     data_type=default_dtype,
+    **extra_get_model_kwargs,
 )
 
-print("loaded cuda model")
-
 cuda_model.eval()
+print("loaded cuda model")
 
 # prepare the cpu model (this is the reference)
 cpu_model = get_model(
@@ -128,45 +149,11 @@
     model_path=args.model_path,
     device_type="cpu",
     data_type=torch.float32,
+    **extra_get_model_kwargs,
 )
 cpu_model.eval()
 print("loaded cpu model")
 
-def sample_sharegpt_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer,
-) -> List[Tuple[str, int, int, None]]:
-    # Load the dataset.
-    with open(dataset_path, encoding='utf-8') as f:
-        dataset = json.load(f)
-    # Filter out the conversations with less than 2 turns.
-    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
-
-    # Shuffle the dataset.
-    random.Random(42).shuffle(dataset)
-
-    # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
-    for i in range(len(dataset)):
-        if len(filtered_dataset) == num_requests:
-            break
-
-        # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
-        prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
-        prompt_len = len(prompt_token_ids)
-        if prompt_len < 32 or prompt_len > args.min_pad_length:
-            # Prune too short sequences.
-            continue
-        filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 def find_eos_index(reference_tokens, eos_token_id):
     result = []
     for sentence in reference_tokens:
@@ -184,21 +171,17 @@ def filter_before_eos(l, filter_indexes):
     from itertools import groupby
     filtered_results = [list(g)[:filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])]
     return [item for sublist in filtered_results for item in sublist]
-
-prompts_and_lens = sample_sharegpt_requests(args.sharegpt_path, args.batch_size, tokenizer)
-print(f"prompt_lengths: {[pl[1] for pl in prompts_and_lens]}")
-prompts = [ids_for_prompt(pl[0], tokenizer) for pl in prompts_and_lens]
 
-padding_length = args.min_pad_length
+def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
+    prompts_and_sizes = sample_sharegpt_requests(args.sharegpt_path, batch_size, tokenizer, seq_length // 2, seq_length, seed)
+    prompt_list = []
+    for prompt, _ in prompts_and_sizes:
+        prompt_list.append(ids_for_prompt(prompt, tokenizer))
 
-has_padding = args.batch_size > 1 or padding_length != 0
-max_len = max([len(prompt) for prompt in prompts])
+    input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
+    return input_ids, padding_kwargs
 
-if has_padding:
-    ids, padding_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
-else:
-    ids = prompts
-    padding_kwargs = {}
+ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
 
 # first test validation level 0
 cpu_validation_info = extract_validation_information(
@@ -231,63 +214,88 @@ def filter_before_eos(l, filter_indexes):
 if len(failed_responses) != 0:
     print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)
 
-# generate aiu validation info
-cuda_validation_info = extract_validation_information(
-    cuda_model,
-    ids.to("cuda"),
-    args.max_new_tokens,
-    GoldenTokenHook(cpu_static_tokens, "cuda"),
-    only_last_token=True,
-    **{k: v.to("cuda") for k,v in padding_kwargs.items()}
-)
-
-print("extracted cuda validation information level 1")
-
-cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
-prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
-prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
-diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
-
 def write_csv(l, path, metric):
     with open(path, 'w') as f:
         f.write(f'{metric}\n')
         for t in l:
             f.write(f"{t[2].item()}\n")
        f.close()
 
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
-
-cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_output_logits.out"))
-cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_output_logits.out"))
+num_test_tokens_per_sequence = args.num_test_tokens_per_sequence
+if num_test_tokens_per_sequence is None:
+    num_test_tokens_per_sequence = args.max_new_tokens
 
-level_1_metrics = capture_level_1_metrics(
-    cpu_validation_info.get_info("logits"),
-    cuda_validation_info.get_info("logits"),
-    top_k_loss_calculator(args.topk_per_token, prob_mean),
-)
-loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
-write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
+prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
+prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
+diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
 
-level_1_metrics = capture_level_1_metrics(
-    cpu_validation_info.get_info("logits"),
-    cuda_validation_info.get_info("logits"),
-    top_k_loss_calculator(args.topk_per_token, prob_std),
-)
-loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
-write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+prob_mean_metrics = []
+prob_std_metrics = []
+prob_diff_metrics = []
+prob_ce_loss_metrics = []
 
-level_1_metrics = capture_level_1_metrics(
-    cpu_validation_info.get_info("logits"),
-    cuda_validation_info.get_info("logits"),
-    top_k_loss_calculator(args.topk_per_token, cross_entropy),
-)
-loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
-write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
 
-level_1_metrics = capture_level_1_metrics(
-    cpu_validation_info.get_info("logits"),
-    cuda_validation_info.get_info("logits"),
-    top_k_loss_calculator(args.topk_per_token, diff_mean),
-)
-loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
-write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
+for i in range(num_test_tokens_per_sequence // args.max_new_tokens):
+    ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
+
+    # only need to compute this once if we aren't generating more test data
+    if num_test_tokens_per_sequence > args.max_new_tokens:
+        cpu_validation_info = extract_validation_information(
+            cpu_model,
+            ids,
+            args.max_new_tokens,
+            LogitsExtractorHook(),
+            attn_algorithm="math",
+            **padding_kwargs
+        )
+        eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
+
+    # generate aiu validation info
+    cuda_validation_info = extract_validation_information(
+        cuda_model,
+        ids.to("cuda"),
+        args.max_new_tokens,
+        GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
+        only_last_token=True,
+        **{k: v.to("cuda") for k,v in padding_kwargs.items()}
+    )
+
+    print("extracted cuda validation information level 1")
+
+    cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out"))
+    cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out"))
+
+    level_1_metrics = capture_level_1_metrics(
+        cpu_validation_info.get_info("logits"),
+        cuda_validation_info.get_info("logits"),
+        top_k_loss_calculator(args.topk_per_token, prob_mean),
+    )
+    prob_mean_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+    level_1_metrics = capture_level_1_metrics(
+        cpu_validation_info.get_info("logits"),
+        cuda_validation_info.get_info("logits"),
+        top_k_loss_calculator(args.topk_per_token, prob_std),
+    )
+    prob_std_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+    level_1_metrics = capture_level_1_metrics(
+        cpu_validation_info.get_info("logits"),
+        cuda_validation_info.get_info("logits"),
+        top_k_loss_calculator(args.topk_per_token, cross_entropy),
+    )
+    prob_ce_loss_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+    level_1_metrics = capture_level_1_metrics(
+        cpu_validation_info.get_info("logits"),
+        cuda_validation_info.get_info("logits"),
+        top_k_loss_calculator(args.topk_per_token, diff_mean),
+    )
+    prob_diff_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
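
For reference, a small standalone sketch of how the new `--extra_get_model_kwargs` values are interpreted (the example keys and values below are hypothetical): each `key=value` token is split on `=` and the value passed through `ast.literal_eval`, falling back to the raw string when it is not a Python literal, so numeric overrides such as `nlayers=2` reach `get_model` as integers.

```python
import ast

# Illustrative inputs, e.g. what argparse would collect from:
#   --extra_get_model_kwargs nlayers=2 some_flag=abc   (some_flag is hypothetical)
raw_kwargs = ["nlayers=2", "some_flag=abc"]

extra_get_model_kwargs = {}
for a in raw_kwargs:
    a_split = a.split("=")
    try:
        extra_get_model_kwargs[a_split[0]] = ast.literal_eval(a_split[1])  # "2" -> 2
    except ValueError:
        extra_get_model_kwargs[a_split[0]] = a_split[1]  # non-literals stay strings

print(extra_get_model_kwargs)  # {'nlayers': 2, 'some_flag': 'abc'}
```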

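Finally, a worked example of the arithmetic behind the new outer loop (the values are illustrative): with `max_new_tokens=128` and `--num_test_tokens_per_sequence 256`, the script makes `256 // 128 = 2` passes; each pass re-seeds `__prepare_inputs` with the pass index `i`, so data is gathered over two different sampled prompt batches, and the four metric lists accumulate results across passes before the CSVs are written once at the end.

```python
# Illustrative arithmetic only; mirrors the loop bound used above.
max_new_tokens = 128                  # --max_new_tokens
num_test_tokens_per_sequence = 256    # --num_test_tokens_per_sequence
num_passes = num_test_tokens_per_sequence // max_new_tokens
# Each pass i calls __prepare_inputs(..., seed=i) to draw a fresh ShareGPT batch.
print(num_passes)  # 2
```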