Commit 22c7f2b

Fixes to activation capture, gemma and CETT
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent f7e988c

5 files changed (+22, -66 lines)

src/activation_capture.py

Lines changed: 4 additions & 2 deletions
@@ -11,7 +11,7 @@ class Hook(Enum):
 
 class ActivationCapture():
     """Helper class to capture activations from model layers."""
-    hooks_available: List[Hook]
+    hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT]
 
     def __init__(self, model):
         self.model = model
@@ -95,4 +95,6 @@ def remove_hooks(self):
 
     def clear_captures(self):
         """Clear captured activations."""
-        self.mlp_activations = {}
+        self.mlp_activations = {
+            hook: {} for hook in self.hooks_available
+        }
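
Note: hooks_available now has a concrete default covering all four hook types, and clear_captures() rebuilds one empty dict per available hook, so mlp_activations can be indexed by hook immediately after a reset. A minimal sketch of the resulting structure (the import path and the model object are assumptions, not shown in this diff):

from src.activation_capture import ActivationCapture, Hook

capture = ActivationCapture(model)   # `model` stands in for any supported skip-connection LM
capture.clear_captures()
# The reset container is now keyed by hook rather than being a single flat dict:
# {Hook.IN: {}, Hook.ACT: {}, Hook.UP: {}, Hook.OUT: {}}
assert all(capture.mlp_activations[h] == {} for h in ActivationCapture.hooks_available)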

src/cett.py

Lines changed: 18 additions & 48 deletions
@@ -2,7 +2,7 @@
 import logging
 import os
 import json
-import tqdm
+from tqdm import tqdm
 import argparse
 
 from datasets import load_dataset
@@ -39,46 +39,16 @@ def CETT(threshold):
     return threshold
 
 
-def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=500):
-    model.activation_capture = model.ACTIVATION_CAPTURE(model)
-    model.activation_capture.register_hooks(hooks=[Hook.UP])
-
-    thresholds = defaultdict(list)
-
-    with torch.no_grad():
-        for batch in dataloader:
-            input_ids = batch["input_ids"]
-            attention_mask = batch["attention_mask"]
-
-            model.activation_capture.clear_captures()
-
-            _ = model(input_ids=input_ids, attention_mask=attention_mask)
-
-            for layer,layer_idx in enumerate(model.activation_capture.get_layers()):
-                activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
-                activations = activations.view(-1, activations.size(-1))
-
-                for i in range(activations.size(0)):
-                    neuron_outputs = activations[i] * layer.mlp.down_proj.weight
-                    threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
-                    thresholds[layer_idx].append(threshold)
-
-    for layer_idx, layer_thresholds in thresholds.items():
-        thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds)
-
-    return thresholds
-
-
 def find_thresholds(
-    model_name,
-    dataset_name,
-    dataset_config,
-    max_samples,
-    cett_target,
-    n_quantiles,
-    save_path,
-    device,
+    model_name: str,
+    dataset_name: str,
+    dataset_config: str,
+    max_samples: int,
+    cett_target: float,
+    n_quantiles: int,
+    save_path: str,
+    seed: int,
+    device: torch.device,
 ):
 
     # Load tokenizer and model
@@ -96,7 +66,7 @@ def find_thresholds(
     model = model.to(device)
 
     model.eval()
-    model.activation_capture = model.ACTIVATION_CAPTURE(model)
+    model.activation_capture = ActivationCapture(model)
     model.activation_capture.register_hooks(hooks=[Hook.UP])
 
     # Load dataset
@@ -107,7 +77,7 @@ def find_thresholds(
         )
     else:
         dataset = load_dataset(dataset_name, split="train", streaming=True)
-    dataset = dataset.shuffle(buffer_size=10000, seed=42)
+    dataset = dataset.shuffle(buffer_size=10000, seed=seed)
 
     def sample_and_tokenize(examples):
         """Sample text chunks before tokenization for efficiency using vectorized operations."""
@@ -130,13 +100,12 @@ def sample_and_tokenize(examples):
     logger.info(f"Beginning to compute thresholds using {max_samples} samples")
     thresholds = defaultdict(list)
     with torch.no_grad():
-        for batch in tqdm.tqdm(dataloader):
+        for batch in tqdm(dataloader, total=max_samples):
             input_ids = batch["input_ids"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
 
-            _ = model(input_ids=input_ids, attention_mask=attention_mask)
+            _ = model(input_ids.squeeze(0))
 
-            for layer,layer_idx in enumerate(model.activation_capture.get_layers()):
+            for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
                 activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
                 activations = activations.view(-1, activations.size(-1))
 
@@ -232,12 +201,13 @@ def parse_args():
 
     find_thresholds(
         model_name=args.model_name,
-        dataset_name=args.dataset_name,
+        dataset_name=args.dataset,
         dataset_config=args.dataset_config,
         max_samples=args.max_samples,
         cett_target=args.cett_target,
         n_quantiles=args.n_quantiles,
         save_path=args.save_path,
-        device=device
+        seed=args.seed,
+        device=device,
     )
 
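
Note: find_thresholds now takes a seed that is threaded into dataset.shuffle (instead of the hard-coded 42), reads the dataset name from args.dataset, and attaches the generic ActivationCapture directly. A hypothetical direct call, where the checkpoint, dataset, and max_samples are illustrative placeholders and cett_target=0.2 / n_quantiles=500 mirror the defaults of the deleted find_threshold helper:

import torch
from src.cett import find_thresholds

find_thresholds(
    model_name="microsoft/Phi-3-mini-4k-instruct",   # placeholder checkpoint
    dataset_name="wikitext",                         # placeholder dataset
    dataset_config="wikitext-103-raw-v1",
    max_samples=1000,
    cett_target=0.2,     # default of the removed find_threshold helper
    n_quantiles=500,     # likewise
    save_path="thresholds.json",
    seed=42,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)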

src/models/gemma3n/activation_capture_gemma.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/models/gemma3n/modelling_gemma_skip.py

Lines changed: 0 additions & 2 deletions
@@ -27,7 +27,6 @@
 from sparse_transformers import sparse_mlp_forward
 
 from src.models.gemma3n.configuration_gemma_skip import Gemma3nSkipConnectionConfig
-from src.models.gemma3n.activation_capture_gemma import ActivationCaptureGemma3n
 from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm
 
 logger = logging.get_logger(__name__)
@@ -413,7 +412,6 @@ def project_per_layer_inputs(
 Gemma3nSkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Gemma3nSkipPreTrainedModel, Gemma3nSkipConnectionModel)
 
 class Gemma3nSkipConnectionForCausalLM(Gemma3nSkipConnectionForCausalLMBase):
-    ACTIVATION_CAPTURE = ActivationCaptureGemma3n
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

src/models/phi3/modelling_phi_skip.py

Lines changed: 0 additions & 3 deletions
@@ -25,7 +25,6 @@
 
 from src.models.phi3.configuration_phi_skip import Phi3SkipConnectionConfig
 from src.modeling_skip import SkipMLP, SkipDecoderLayer, FastLoRAProjection, build_skip_connection_model, build_skip_connection_model_for_causal_lm
-from .activation_capture_phi import ActivationCapturePhi3
 logger = logging.get_logger(__name__)
 
 
@@ -339,8 +338,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 Phi3SkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Phi3SkipPreTrainedModel, Phi3SkipConnectionModel)
 
 class Phi3SkipConnectionForCausalLM(Phi3SkipConnectionForCausalLMBase):
-    ACTIVATION_CAPTURE = ActivationCapturePhi3
-
     _keys_to_ignore_on_load_missing = [
         "model.layers.*.mlp.combined_proj_buffer",
         "model.layers.*.mlp.down_proj_buffer",
