Skip to content

Commit 37a238e

Browse files
committed
updating cett code
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 22c7f2b commit 37a238e

File tree

1 file changed

+167
-28
lines changed

1 file changed

+167
-28
lines changed

src/cett.py

Lines changed: 167 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from datasets import load_dataset
99
import torch
10+
from torch.utils.data import DataLoader as TorchDataLoader
1011
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
1112
from transformers.trainer_utils import set_seed
1213

@@ -16,39 +17,173 @@
1617
logging.basicConfig(level=logging.INFO)
1718
logger = logging.getLogger(__name__)
1819

19-
def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000):
20-
norms = neuron_outputs.norm(dim=0)
21-
quantiles = norms.quantile(torch.linspace(0,1,n_quantiles))
22-
tot_norm = neuron_outputs.sum(dim=1).norm()
2320

24-
def CETT(threshold):
25-
threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=1).norm()
26-
return threshold_norm / tot_norm
21+
import copy
22+
class ThresholdEvaluator():
    """Compare a model's MLP outputs with and without activation thresholding.

    Forward hooks on every ``layer.mlp`` record the MLP output, and forward
    hooks on every ``layer.mlp.act_fn`` zero out activations that do not
    exceed a per-neuron threshold.

    The wrapped model is expected to expose ``model.model.layers`` (each layer
    having ``mlp.act_fn`` and ``mlp.down_proj``) and
    ``model.config.intermediate_size``.
    """

    def __init__(self, model, thresholds):
        """Store the model and derive per-neuron thresholds.

        Args:
            model: Model to evaluate; called as ``model(**inp)``.
            thresholds: Per-layer threshold values, indexed by layer index.
        """
        self.model = model
        self.thresholds = thresholds

        # Hook bookkeeping: captured MLP outputs per layer and live handles.
        self.mlp_outputs = defaultdict(list)
        self.handles = []

        self.compute_neuron_thresholds(thresholds)

    def get_layers(self):
        """Return the sequence of decoder layers of the wrapped model."""
        return self.model.model.layers

    def compute_neuron_thresholds(self, thresholds):
        """Fill ``self.neuron_thresholds`` with one threshold per neuron.

        Each layer's threshold is scaled by the per-column norm of that
        layer's ``down_proj`` weight, giving a tensor of shape
        ``(n_layers, intermediate_size)``.
        """
        layers = self.get_layers()
        self.neuron_thresholds = torch.zeros(
            len(layers), self.model.config.intermediate_size
        )
        with torch.no_grad():
            # enumerate() is required: the layer container yields modules,
            # not (index, module) pairs.
            for layer_idx, layer in enumerate(layers):
                norms = layer.mlp.down_proj.weight.norm(dim=0)
                self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms

    def _inspect_hook(self, layer_idx):
        """Build a forward hook that records the flattened MLP output."""
        def hook(module, input, output):
            # Store a detached copy flattened to (tokens, features) so later
            # in-place changes to the live output cannot alter the capture.
            out = output.detach().reshape(-1, output.size(-1)).clone()
            self.mlp_outputs[layer_idx].append(out)
            return output
        return hook

    def _threshold_hook(self, layer_idx):
        """Build a forward hook that zeroes activations at or below threshold."""
        def hook(module, input, output):
            # Thresholds are stored on CPU; move them next to the activations.
            limit = self.neuron_thresholds[layer_idx].to(output.device)
            # NOTE(review): the comparison is signed, so negative activations
            # are always dropped when thresholds are non-negative - confirm
            # that a magnitude comparison is not what was intended.
            mask = output > limit
            return output * mask
        return hook

    def apply_thresholds(self):
        """Register the thresholding hook on every layer's activation function."""
        for layer_idx, layer in enumerate(self.get_layers()):
            handle = layer.mlp.act_fn.register_forward_hook(
                self._threshold_hook(layer_idx)
            )
            self.handles.append(handle)

    def apply_hooks(self):
        """Register the output-capturing hook on every layer's MLP."""
        for layer_idx, layer in enumerate(self.get_layers()):
            handle = layer.mlp.register_forward_hook(
                self._inspect_hook(layer_idx)
            )
            self.handles.append(handle)

    def clear_captures(self):
        """Discard all MLP outputs captured so far."""
        self.mlp_outputs = defaultdict(list)

    def remove_hooks(self):
        """Remove every hook registered through this evaluator."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def _run_and_collect(self, inputs):
        """Run the model over ``inputs`` and return concatenated captures."""
        with torch.no_grad():
            for inp in inputs:
                self.model(**inp)
        collected = {
            idx: torch.cat(chunks, dim=0)
            for idx, chunks in self.mlp_outputs.items()
        }
        self.clear_captures()
        return collected

    def evaluate(self, inputs):
        """Run the model with and without thresholding and collect MLP outputs.

        Args:
            inputs: Re-iterable collection of keyword-argument dicts; it is
                traversed twice, so a one-shot generator will yield an empty
                second pass.

        Returns:
            A ``(ground_truth_outputs, threshold_outputs)`` tuple. Each is a
            dict mapping layer index to the MLP outputs of all inputs
            concatenated along dim 0.
        """
        self.clear_captures()
        try:
            self.apply_hooks()
            ground_truth_outputs = self._run_and_collect(inputs)

            # The capture hooks stay active so the second pass records the
            # MLP outputs produced under thresholding.
            self.apply_thresholds()
            threshold_outputs = self._run_and_collect(inputs)
        finally:
            # Always restore the model, even if a forward pass fails, so
            # hooks do not accumulate across calls.
            self.remove_hooks()
            self.clear_captures()

        return ground_truth_outputs, threshold_outputs
101+
102+
103+
104+
#
105+
# TODO:
106+
# 1. Test out precomputing down_proj norms and see if that improves performance
107+
# 2. Ensure that the thresholds lead to reasonable results for downstream evaluation
108+
#
109+
#
110+
111+
112+
113+
114+
def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None):
115+
if not norms: # pass both or neither
116+
norms = norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
117+
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
118+
threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1)
119+
return threshold_norm / tot_norm
120+
121+
'''
122+
def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000):
123+
neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:])
124+
norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
125+
min_value = norms.min()
126+
max_value = norms.quantile(0.99)
127+
threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
128+
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
129+
thresholds = torch.zeros(neuron_outputs.size(0))
130+
131+
initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
132+
thresholds[initial_cett < cett_target] = max_value
133+
134+
for j in tqdm(range(neuron_outputs.size(0))):
135+
if thresholds[j] == 0:
136+
left = 0
137+
right = n_thresholds
138+
while left < right:
139+
mid = (left + right) // 2
140+
cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j])
141+
if cett <= cett_target:
142+
left = mid + 1
143+
else:
144+
right = mid
145+
thresholds[j] = threshold_grid[left]
146+
return thresholds
147+
'''
148+
149+
def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000):
150+
neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:])
151+
norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
152+
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
153+
154+
min_value = norms.min()
155+
max_value = norms.quantile(0.99)
156+
threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
157+
#initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
158+
#outlier_mask = initial_cett > cett_target
159+
28160
left = 0
29-
right = quantiles.size(0)
30-
threshold = 0
161+
right = n_thresholds
31162
while left < right:
163+
print(left,right)
32164
mid = (left + right) // 2
33-
cett = CETT(quantiles[mid])
165+
#cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm)[outlier_mask].mean()
166+
cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm).mean()
34167
if cett <= cett_target:
35168
left = mid + 1
36-
threshold = quantiles[mid]
37169
else:
38-
right = mid - 1
39-
return threshold
170+
right = mid
171+
return threshold_grid[left]
40172

41173

42174
def find_thresholds(
43175
model_name: str,
44176
dataset_name: str,
45177
dataset_config: str,
46-
max_samples: int,
47-
cett_target: float,
48-
n_quantiles: int,
49178
save_path: str,
50-
seed: int,
51-
device: torch.device,
179+
batch_size: int = 8,
180+
max_samples: int = 128,
181+
max_length: int = 256,
182+
cett_target: float = 0.2,
183+
n_thresholds: int = 10000,
184+
num_workers: int = 8,
185+
seed: int = 42,
186+
device: torch.device = torch.device("cpu"),
52187
):
53188

54189
# Load tokenizer and model
@@ -82,37 +217,41 @@ def find_thresholds(
82217
def sample_and_tokenize(examples):
83218
"""Sample text chunks before tokenization for efficiency using vectorized operations."""
84219
texts = examples["text"]
85-
tokenized = tokenizer(texts, return_tensors="pt")
220+
tokenized = tokenizer(
221+
texts,
222+
max_length=max_length,
223+
truncation=True,
224+
return_tensors="pt"
225+
)
86226

87227
# Convert to lists
88228
return {
89-
"text": texts,
90229
"input_ids": tokenized["input_ids"],
230+
"attention_mask": tokenized["attention_mask"]
91231
}
92232

93233
# Tokenize
94234
dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False)
95235
dataset = dataset.with_format("torch")
96236

97-
dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2) # type: ignore
237+
dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=num_workers, pin_memory=False, prefetch_factor=2) # type: ignore
98238

99239
# Compute thresholds for each layer across all dataset entries
100240
logger.info(f"Beginning to compute thresholds using {max_samples} samples")
101241
thresholds = defaultdict(list)
102242
with torch.no_grad():
103243
for batch in tqdm(dataloader, total=max_samples):
104244
input_ids = batch["input_ids"].to(device)
245+
attention_mask = batch["attention_mask"].to(device)
105246

106-
_ = model(input_ids.squeeze(0))
247+
_ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0))
107248

108249
for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
250+
down_weight = layer.mlp.down_proj.weight
109251
activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
110-
activations = activations.view(-1, activations.size(-1))
111-
112-
for i in range(activations.size(0)):
113-
neuron_outputs = activations[i] * layer.mlp.down_proj.weight
114-
threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
115-
thresholds[layer_idx].append(threshold)
252+
neuron_outputs = activations.unsqueeze(-2) * down_weight
253+
threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds)
254+
thresholds[layer_idx].append(threshold)
116255

117256
model.activation_capture.clear_captures()
118257
if device.type == "cuda":

0 commit comments

Comments
 (0)