Skip to content

Commit f7e988c

Browse files
committed
Basic script for CETT and fix phi3
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 7c286b1 commit f7e988c

File tree

3 files changed

+598
-68
lines changed

3 files changed

+598
-68
lines changed

src/cett.py

Lines changed: 199 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
from collections import defaultdict
2+
import logging
3+
import os
4+
import json
5+
import tqdm
6+
import argparse
17

2-
8+
from datasets import load_dataset
39
import torch
10+
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
11+
from transformers.trainer_utils import set_seed
412

513
from src.activation_capture import ActivationCapture, Hook
614

15+
# Setup logging
16+
logging.basicConfig(level=logging.INFO)
17+
logger = logging.getLogger(__name__)
18+
719
def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000):
820
norms = neuron_outputs.norm(dim=0)
921
quantiles = norms.quantile(torch.linspace(0,1,n_quantiles))
@@ -31,7 +43,7 @@ def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=50
3143
model.activation_capture = model.ACTIVATION_CAPTURE(model)
3244
model.activation_capture.register_hooks(hooks=[Hook.UP])
3345

34-
thresholds = []
46+
thresholds = defaultdict(list)
3547

3648
with torch.no_grad():
3749
for batch in dataloader:
@@ -42,13 +54,190 @@ def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=50
4254

4355
_ = model(input_ids=input_ids, attention_mask=attention_mask)
4456

45-
activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
46-
activations = activations.view(-1, activations.size(-1))
57+
for layer,layer_idx in enumerate(model.activation_capture.get_layers()):
58+
activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
59+
activations = activations.view(-1, activations.size(-1))
60+
61+
for i in range(activations.size(0)):
62+
neuron_outputs = activations[i] * layer.mlp.down_proj.weight
63+
threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
64+
thresholds[layer_idx].append(threshold)
65+
66+
for layer_idx, layer_thresholds in thresholds.items():
67+
thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds)
68+
69+
return thresholds
70+
71+
72+
73+
def find_thresholds(
74+
model_name,
75+
dataset_name,
76+
dataset_config,
77+
max_samples,
78+
cett_target,
79+
n_quantiles,
80+
save_path,
81+
device,
82+
):
83+
84+
# Load tokenizer and model
85+
logger.info(f"Loading model: {model_name}")
86+
tokenizer = AutoTokenizer.from_pretrained(model_name)
87+
tokenizer.pad_token = tokenizer.eos_token
88+
89+
model = AutoModelForCausalLM.from_pretrained(
90+
model_name,
91+
torch_dtype=torch.float32,
92+
device_map="auto" if device.type == "cuda" else None,
93+
)
94+
95+
if device.type != "cuda":
96+
model = model.to(device)
97+
98+
model.eval()
99+
model.activation_capture = model.ACTIVATION_CAPTURE(model)
100+
model.activation_capture.register_hooks(hooks=[Hook.UP])
101+
102+
# Load dataset
103+
logger.info(f"Loading dataset: {dataset_name}")
104+
if dataset_config:
105+
dataset = load_dataset(
106+
dataset_name, dataset_config, split="train", streaming=True
107+
)
108+
else:
109+
dataset = load_dataset(dataset_name, split="train", streaming=True)
110+
dataset = dataset.shuffle(buffer_size=10000, seed=42)
111+
112+
def sample_and_tokenize(examples):
113+
"""Sample text chunks before tokenization for efficiency using vectorized operations."""
114+
texts = examples["text"]
115+
tokenized = tokenizer(texts, return_tensors="pt")
116+
117+
# Convert to lists
118+
return {
119+
"text": texts,
120+
"input_ids": tokenized["input_ids"],
121+
}
122+
123+
# Tokenize
124+
dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False)
125+
dataset = dataset.with_format("torch")
47126

48-
for i in range(activations.size(0)):
49-
neuron_outputs = activations[i] * model.model.layers[0].mlp.down_proj.weight
50-
threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
51-
thresholds.append(threshold)
127+
dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2) # type: ignore
128+
129+
# Compute thresholds for each layer across all dataset entries
130+
logger.info(f"Beginning to compute thresholds using {max_samples} samples")
131+
thresholds = defaultdict(list)
132+
with torch.no_grad():
133+
for batch in tqdm.tqdm(dataloader):
134+
input_ids = batch["input_ids"].to(device)
135+
attention_mask = batch["attention_mask"].to(device)
136+
137+
_ = model(input_ids=input_ids, attention_mask=attention_mask)
138+
139+
for layer,layer_idx in enumerate(model.activation_capture.get_layers()):
140+
activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
141+
activations = activations.view(-1, activations.size(-1))
142+
143+
for i in range(activations.size(0)):
144+
neuron_outputs = activations[i] * layer.mlp.down_proj.weight
145+
threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
146+
thresholds[layer_idx].append(threshold)
147+
148+
model.activation_capture.clear_captures()
149+
if device.type == "cuda":
150+
torch.cuda.empty_cache()
151+
152+
for layer_idx, layer_thresholds in thresholds.items():
153+
thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds)
154+
155+
# Save layerwise thresholds as record in central json file
156+
if not os.path.exists(save_path):
157+
with open("save_path", mode="r", encoding="utf-8") as read_file:
158+
threshold_dict = json.load(read_file)
159+
else:
160+
threshold_dict = {}
161+
threshold_dict[model_name] = thresholds
162+
with open("save_path", mode="r", encoding="utf-8") as write_file:
163+
json.dump(threshold_dict, write_file)
164+
165+
166+
167+
def parse_args():
168+
parser = argparse.ArgumentParser(
169+
description="Generate training dataset for sparsity predictors"
170+
)
171+
parser.add_argument(
172+
"--model_name",
173+
type=str,
174+
required=True,
175+
help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)",
176+
)
177+
parser.add_argument(
178+
"--dataset",
179+
type=str,
180+
default="allenai/c4",
181+
help="Dataset name (default: allenai/c4)",
182+
)
183+
parser.add_argument(
184+
"--dataset_config",
185+
type=str,
186+
default="en",
187+
help="Dataset configuration (e.g., en for C4)",
188+
)
189+
parser.add_argument(
190+
"--save_path",
191+
type=str,
192+
default="thresholds.json",
193+
help="Path to json file for thresholds",
194+
)
195+
parser.add_argument(
196+
"--max_samples",
197+
type=int,
198+
default=500,
199+
help="Maximum number of samples to process",
200+
)
201+
parser.add_argument(
202+
"--cett_target",
203+
type=float,
204+
default=0.2,
205+
help="Optimal CETT value for threshold-finding",
206+
)
207+
parser.add_argument(
208+
"--n_quantiles",
209+
type=int,
210+
default=500,
211+
help="Number of quantiles to sort neuron outputs into for threshold-finding",
212+
)
213+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
214+
parser.add_argument(
215+
"--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)"
216+
)
217+
218+
return parser.parse_args()
219+
220+
221+
if __name__ == '__main__':
222+
args = parse_args()
223+
224+
# Set seed
225+
set_seed(args.seed)
226+
227+
# Setup device
228+
if args.device == "auto":
229+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
230+
else:
231+
device = torch.device(args.device)
52232

53-
return sum(thresholds)/len(thresholds)
54-
233+
find_thresholds(
234+
model_name=args.model_name,
235+
dataset_name=args.dataset_name,
236+
dataset_config=args.dataset_config,
237+
max_samples=args.max_samples,
238+
cett_target=args.cett_target,
239+
n_quantiles=args.n_quantiles,
240+
save_path=args.save_path,
241+
device=device
242+
)
243+

src/models/phi3/activation_capture_phi.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)