From 7c286b11fa3af6fdb66b6f54bc0972b17bab1a5a Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Thu, 10 Jul 2025 13:19:56 -0400 Subject: [PATCH 1/7] Added basic code for CETT threshold calculation and refactored activation capture. Signed-off-by: Kira Selby --- generate_dataset.py | 10 +- measure_contextual_sparsity.py | 23 +- src/activation_capture.py | 181 +++++------- src/cett.py | 54 ++++ src/modeling_skip.py | 4 +- src/models/phi3/modelling_phi_skip.py | 399 -------------------------- 6 files changed, 133 insertions(+), 538 deletions(-) create mode 100644 src/cett.py delete mode 100644 src/models/phi3/modelling_phi_skip.py diff --git a/generate_dataset.py b/generate_dataset.py index 1c3e6d8..6dc3807 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -43,7 +43,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.trainer_utils import set_seed -from src.activation_capture import ActivationCaptureTraining +from src.activation_capture import Hook # Setup logging logging.basicConfig(level=logging.INFO) @@ -120,14 +120,14 @@ def process_batch( hidden_states_dict = {} mlp_activations_dict = {} for layer_idx in range(num_layers): - hidden_state = model.activation_capture.get_hidden_states(layer_idx)[0] + hidden_state = model.activation_capture.mlp_activations[Hook.IN][layer_idx][0] hidden_states_dict[layer_idx] = ( hidden_state.view(-1, hidden_state.shape[-1]) .cpu() .numpy() .astype(np.float32) ) - mlp_activation = model.activation_capture.get_gate_activations(layer_idx) + mlp_activation = model.activation_capture.mlp_activations[Hook.ACT][layer_idx] mlp_activations_dict[layer_idx] = ( mlp_activation[0] .view(-1, mlp_activation.shape[-1]) @@ -172,8 +172,8 @@ def generate_dataset( model = model.to(device) model.eval() - model.activation_capture = ActivationCaptureTraining(model) - model.activation_capture.register_hooks() + model.activation_capture = model.ACTIVATION_CAPTURE(model) + model.activation_capture.register_hooks(hooks=[Hook.IN, Hook.ACT]) # Get model dimensions hidden_dim = model.config.hidden_size diff --git a/measure_contextual_sparsity.py b/measure_contextual_sparsity.py index 8e70939..21f8036 100644 --- a/measure_contextual_sparsity.py +++ b/measure_contextual_sparsity.py @@ -13,7 +13,7 @@ from transformers.trainer_utils import set_seed import matplotlib.pyplot as plt -from src.activation_capture import ActivationCaptureDefault +from src.activation_capture import Hook # Setup logging logging.basicConfig(level=logging.INFO) @@ -28,16 +28,14 @@ def __init__(self, model, tokenizer, device): self.tokenizer = tokenizer self.device = device - model.activation_capture = ActivationCaptureDefault(model) - model.activation_capture.register_hooks() + model.activation_capture = model.ACTIVATION_CAPTURE(model) + model.activation_capture.register_hooks(hooks=[Hook.ACT]) self.num_layers = len(self.model.activation_capture.get_layers()) self.reset_buffers() def reset_buffers(self): - self.mlp_sparsity = {} - self.mlp_sparsity["gate"] = defaultdict(list) - self.mlp_sparsity["up"] = defaultdict(list) + self.mlp_sparsity = defaultdict(list) self.num_seqs = 0 def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): @@ -54,26 +52,19 @@ def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): # Compute sparsity for layer_idx in range(self.num_layers): - sparsity_masks_gate = ( - self.model.activation_capture.get_gate_activations(layer_idx) <= 0 - ) - sparsity_masks_up = ( - 
self.model.activation_capture.get_up_activations(layer_idx) <= 0 + sparsity_masks = ( + self.model.activation_capture.mlp_activations[Hook.ACT][layer_idx] <= 0 ) # Naive sparsity computation self.mlp_sparsity["gate"][layer_idx].append( - sparsity_masks_gate.float().mean().item() - ) - self.mlp_sparsity["up"][layer_idx].append( - sparsity_masks_up.float().mean().item() + sparsity_masks.float().mean().item() ) # Level of sparsity after union over batch dim # union_sparsity_mask = sparsity_masks.any(dim=0) # self.union_sparsity[batch_size][layer_idx].append(union_sparsity_mask.float().mean().item()) - # TODO: Add HNSW sparsity computation for both attn heads and mlp neurons # TODO: Compute union sparsity over multiple different batch sizes # Clear GPU tensors from capture to free memory diff --git a/src/activation_capture.py b/src/activation_capture.py index bbc3ca3..3cedf6c 100644 --- a/src/activation_capture.py +++ b/src/activation_capture.py @@ -1,54 +1,91 @@ -from typing_extensions import override -import torch.nn.functional as F -from abc import ABC, abstractmethod +from enum import Enum +from typing import List -class ActivationCapture(ABC): +class Hook(Enum): + IN = "IN" + ACT = "ACT" + UP = "UP" + OUT = "OUT" + + +class ActivationCapture(): """Helper class to capture activations from model layers.""" - has_gate_proj: bool - has_up_proj: bool + hooks_available: List[Hook] def __init__(self, model): self.model = model - self.mlp_activations = {} + self.mlp_activations = { + hook: {} for hook in self.hooks_available + } self.handles = [] - @abstractmethod - def _register_gate_hook(self, layer_idx, layer): - pass + def _register_in_hook(self, layer_idx, layer): + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach() + return output + handle = layer.mlp.register_forward_hook(hook) + return handle + + def _register_act_hook(self, layer_idx, layer): + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach() + return output + handle = layer.mlp.act_fn.register_forward_hook(hook) + return handle - @abstractmethod def _register_up_hook(self, layer_idx, layer): - pass + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach() + return output + handle = layer.mlp.down_proj.register_forward_hook(hook) + return handle + + def _register_out_hook(self, layer_idx, layer): + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach() + return output + handle = layer.mlp.register_forward_hook(hook) + return handle - @abstractmethod def get_layers(self): - pass - - - @abstractmethod - def get_gate_activations(self, layer_idx): - """Get combined MLP activations for a layer.""" - pass + return self.model.get_decoder().layers - def register_hooks(self): + def register_hooks(self, hooks=(Hook.ACT, Hook.UP, Hook.OUT)): """Register forward hooks to capture activations.""" # Clear any existing hooks self.remove_hooks() # Hook into each transformer layer - for i, layer in enumerate(self.get_layers()): - # Capture MLP gate activations (after activation function) - if self.has_gate_proj: - handle = self._register_gate_hook(i, layer) + for i, layer in enumerate(self.get_layers()): + # Hooks capturing inputs to the MLP layer + if Hook.IN in 
hooks and Hook.IN in self.hooks_available: + handle = self._register_in_hook(i, layer) if handle is not None: self.handles.append(handle) - - # Also capture up_proj activations - if self.has_up_proj: + + # Hooks capturing inputs to the activation function + if Hook.ACT in hooks and Hook.ACT in self.hooks_available: + handle = self._register_act_hook(i, layer) + if handle is not None: + self.handles.append(handle) + + # Hooks capturing inputs to the down projection + if Hook.UP in hooks and Hook.UP in self.hooks_available: handle = self._register_up_hook(i, layer) if handle is not None: self.handles.append(handle) + + # Hooks capturing the final MLP output + if Hook.OUT in hooks and Hook.OUT in self.hooks_available: + handle = self._register_out_hook(i, layer) + if handle is not None: + self.handles.append(handle) + def remove_hooks(self): """Remove all registered hooks.""" @@ -59,91 +96,3 @@ def remove_hooks(self): def clear_captures(self): """Clear captured activations.""" self.mlp_activations = {} - - - -class ActivationCaptureDefault(ActivationCapture): - """Helper class to capture activations from model layers.""" - has_gate_proj: bool = True - has_up_proj: bool = True - - def get_layers(self): - return self.model.get_decoder().layers - - def _create_mlp_hook(self, layer_idx, proj_type): - def hook(module, input, output): - key = f"{layer_idx}_{proj_type}" - # Just detach, don't clone or move to CPU yet - self.mlp_activations[key] = output.clone().detach() - return output - return hook - - def _register_gate_hook(self, layer_idx, layer): - handle = layer.mlp.gate_proj.register_forward_hook( - self._create_mlp_hook(layer_idx, 'gate') - ) - return handle - - def _register_up_hook(self, layer_idx, layer): - handle = layer.mlp.up_proj.register_forward_hook( - self._create_mlp_hook(layer_idx, 'up') - ) - return handle - - def get_gate_activations(self, layer_idx): - gate_key = f"{layer_idx}_gate" - if gate_key in self.mlp_activations: - gate_act = self.mlp_activations[gate_key] - return F.silu(gate_act) - return None - - def get_up_activations(self, layer_idx): - up_key = f"{layer_idx}_up" - if up_key in self.mlp_activations: - up_act = self.mlp_activations[up_key] - return up_act - return None - -class ActivationCaptureTraining(ActivationCaptureDefault): - """Additional Hidden State capture for training dataset generation""" - def __init__(self, model): - super().__init__(model) - self.hidden_states = {} - - def _create_hidden_state_hook(self, layer_idx, layer): - def hook(module, args, kwargs, output): - # args[0] is the input hidden states to the layer - if len(args) > 0: - # Just detach, don't clone or move to CPU yet - self.hidden_states[layer_idx] = args[0].clone().detach() - return output - return hook - - def _register_hidden_state_hook(self, layer_idx, layer): - handle = layer.register_forward_hook( - self._create_hidden_state_hook(layer_idx, layer), - with_kwargs=True - ) - return handle - - @override - def clear_captures(self): - """Clear captured activations.""" - super().clear_captures() - self.hidden_states = {} - - @override - def register_hooks(self): - """Register forward hooks to capture activations.""" - # Clear any existing hooks - super().register_hooks() - # Hook into each transformer layer - for i, layer in enumerate(self.get_layers()): - # Capture hidden states before MLP - handle = self._register_hidden_state_hook(i, layer) - if handle is not None: - self.handles.append(handle) - - def get_hidden_states(self, layer_idx): - """Get hidden states for a layer.""" - 
return self.hidden_states[layer_idx] diff --git a/src/cett.py b/src/cett.py new file mode 100644 index 0000000..3549958 --- /dev/null +++ b/src/cett.py @@ -0,0 +1,54 @@ + + +import torch + +from src.activation_capture import ActivationCapture, Hook + +def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000): + norms = neuron_outputs.norm(dim=0) + quantiles = norms.quantile(torch.linspace(0,1,n_quantiles)) + tot_norm = neuron_outputs.sum(dim=1).norm() + + def CETT(threshold): + threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=1).norm() + return threshold_norm / tot_norm + + left = 0 + right = quantiles.size(0) + threshold = 0 + while left < right: + mid = (left + right) // 2 + cett = CETT(quantiles[mid]) + if cett <= cett_target: + left = mid + 1 + threshold = quantiles[mid] + else: + right = mid - 1 + return threshold + + +def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=500): + model.activation_capture = model.ACTIVATION_CAPTURE(model) + model.activation_capture.register_hooks(hooks=[Hook.UP]) + + thresholds = [] + + with torch.no_grad(): + for batch in dataloader: + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + + model.activation_capture.clear_captures() + + _ = model(input_ids=input_ids, attention_mask=attention_mask) + + activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] + activations = activations.view(-1, activations.size(-1)) + + for i in range(activations.size(0)): + neuron_outputs = activations[i] * model.model.layers[0].mlp.down_proj.weight + threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) + thresholds.append(threshold) + + return sum(thresholds)/len(thresholds) + \ No newline at end of file diff --git a/src/modeling_skip.py b/src/modeling_skip.py index 18482c3..91b53e1 100644 --- a/src/modeling_skip.py +++ b/src/modeling_skip.py @@ -20,7 +20,7 @@ from transformers.utils.import_utils import is_torch_flex_attn_available from sparse_transformers import WeightCache, sparse_mlp_forward -from src.activation_capture import ActivationCaptureDefault +from src.activation_capture import ActivationCapture if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask @@ -352,7 +352,7 @@ def forward( def build_skip_connection_model_for_causal_lm(pretrained_model_class: type[PreTrainedModel], base_model_class: type[PreTrainedModel]): - ACTIVATION_CAPTURE = ActivationCaptureDefault + ACTIVATION_CAPTURE = ActivationCapture class SkipConnectionModelForCausalLM(pretrained_model_class, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/models/phi3/modelling_phi_skip.py b/src/models/phi3/modelling_phi_skip.py deleted file mode 100644 index d4303c4..0000000 --- a/src/models/phi3/modelling_phi_skip.py +++ /dev/null @@ -1,399 +0,0 @@ - -import math -from typing import Optional, Tuple, Union - -import torch -from torch import nn - -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.processing_utils import Unpack -from transformers.utils import logging -from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache -from transformers.modeling_utils import PreTrainedModel -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs - -from transformers.models.phi3.modeling_phi3 import( - Phi3MLP, Phi3Attention, Phi3RMSNorm, Phi3RotaryEmbedding, -) - -from transformers.utils.import_utils import 
is_torch_flex_attn_available - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from transformers.integrations.flex_attention import make_flex_block_causal_mask - -from src.models.phi3.configuration_phi_skip import Phi3SkipConnectionConfig -from src.modeling_skip import SkipMLP, SkipDecoderLayer, FastLoRAProjection, build_skip_connection_model, build_skip_connection_model_for_causal_lm -from .activation_capture_phi import ActivationCapturePhi3 -logger = logging.get_logger(__name__) - - -class Phi3SkipMLP(SkipMLP): - def __init__(self, hidden_size, intermediate_size, sparsity): - super().__init__(hidden_size, intermediate_size, sparsity, False, "silu") - self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) - - def _fix_unloaded_weights(self): - gate_proj_weight, up_proj_weight = self.gate_up_proj.weight.chunk(2, dim=0) - self.gate_proj.load_state_dict({'weight': gate_proj_weight}, assign=True) - self.up_proj.load_state_dict({'weight': up_proj_weight}, assign=True) - del self.gate_up_proj - return self - - -class Phi3SkipDecoderLayer(SkipDecoderLayer): - def _init_components(self, config, layer_idx): - self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) - self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) - self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) - - def _set_mlp_train(self, config, layer_idx): - self.mlp = Phi3MLP(config) - - def _set_mlp_inference(self, config, layer_idx): - self.mlp = Phi3SkipMLP( - config.hidden_size, - config.intermediate_size, - config.sparsity, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ): - """ - Args: - hidden_states (`torch.FloatTensor`): - input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_value (`Cache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - if not self.training: # Use PyTorch's built-in training flag - self._compute_binary_mask(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.training and self.is_training_config: - predictor_loss = self.compute_predictor_loss(hidden_states) - else: - predictor_loss = None - hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs, predictor_loss - - -class Phi3SkipPreTrainedModel(PreTrainedModel): - config_class = Phi3SkipConnectionConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Phi3SkipDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - _version = "0.0.5" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Phi3RMSNorm): - module.weight.data.fill_(1.0) - -Phi3SkipConnectionModelBase = build_skip_connection_model(Phi3SkipPreTrainedModel) - -class Phi3SkipConnectionModel(Phi3SkipConnectionModelBase): - def _init_components(self, config): - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # type: ignore - self.layers = nn.ModuleList( - [Phi3SkipDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Phi3RotaryEmbedding(config=config) - - def _initialize_unloaded_weights(self): - for module in self.modules(): - if any(hasattr(p, 'is_meta') and p.is_meta for p in module.parameters()): - if isinstance(module, FastLoRAProjection): - module = module.to_empty(device="cpu") - with torch.no_grad(): - torch.nn.init.xavier_normal_(module.down.weight) - torch.nn.init.zeros_(module.up.weight) # Initialize up projection to zeros for stable training - elif isinstance(module, Phi3SkipMLP): - module._fix_weights() - - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: 
torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Phi3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - - -Phi3SkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Phi3SkipPreTrainedModel, Phi3SkipConnectionModel) - -class Phi3SkipConnectionForCausalLM(Phi3SkipConnectionForCausalLMBase): - ACTIVATION_CAPTURE = 
ActivationCapturePhi3 - - _keys_to_ignore_on_load_missing = [ - "model.layers.*.mlp.combined_proj_buffer", - "model.layers.*.mlp.down_proj_buffer", - "model.layers.*.mlp.init_mask", - "model.layers.*.mlp.weight_cache", - "model.layers.*.mlp_lora_proj.down.weight", - "model.layers.*.mlp_lora_proj.intermediate", - "model.layers.*.mlp_lora_proj.output", - "model.layers.*.mlp_lora_proj.up.weight", - "model.layers.*.mlp_mask", - "model.layers.*.mlp.gate_proj.weight", - "model.layers.*.mlp.up_proj.weight", - "model.layers.*.standard_mlp.gate_proj.weight", - "model.layers.*.standard_mlp.up_proj.weight", - "model.layers.*.standard_mlp.down_proj.weight" - ] - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - logits_to_keep=None, - **kwargs, - ): - # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the - # process - - # When the first time input length reached long and short factor switching point, enforce re-compute cache - # It will cause downside of slower at this single token position, however, better than current failure. - if ( - past_key_values - and self.config.rope_scaling - and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 - ): - past_length = cache_position[0] - if past_length <= self.config.original_max_position_embeddings: - past_key_values = None - - model_inputs = super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - cache_position=cache_position, - position_ids=position_ids, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - **kwargs, - ) - return model_inputs - \ No newline at end of file From f7e988cc86dd8226afa199f19b83f6294dd0b9b6 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Mon, 14 Jul 2025 09:39:47 -0400 Subject: [PATCH 2/7] Basic script for CETT and fix phi3 Signed-off-by: Kira Selby --- src/cett.py | 209 +++++++++++- src/models/phi3/activation_capture_phi.py | 58 ---- src/models/phi3/modelling_phi_skip.py | 399 ++++++++++++++++++++++ 3 files changed, 598 insertions(+), 68 deletions(-) delete mode 100644 src/models/phi3/activation_capture_phi.py create mode 100644 src/models/phi3/modelling_phi_skip.py diff --git a/src/cett.py b/src/cett.py index 3549958..6a53ed7 100644 --- a/src/cett.py +++ b/src/cett.py @@ -1,9 +1,21 @@ +from collections import defaultdict +import logging +import os +import json +import tqdm +import argparse - +from datasets import load_dataset import torch +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM +from transformers.trainer_utils import set_seed from src.activation_capture import ActivationCapture, Hook +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000): norms = neuron_outputs.norm(dim=0) quantiles = norms.quantile(torch.linspace(0,1,n_quantiles)) @@ -31,7 +43,7 @@ def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=50 model.activation_capture = model.ACTIVATION_CAPTURE(model) model.activation_capture.register_hooks(hooks=[Hook.UP]) - thresholds = [] + thresholds = defaultdict(list) with torch.no_grad(): for batch in dataloader: @@ -42,13 +54,190 @@ def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=50 _ = 
model(input_ids=input_ids, attention_mask=attention_mask) - activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] - activations = activations.view(-1, activations.size(-1)) + for layer,layer_idx in enumerate(model.activation_capture.get_layers()): + activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] + activations = activations.view(-1, activations.size(-1)) + + for i in range(activations.size(0)): + neuron_outputs = activations[i] * layer.mlp.down_proj.weight + threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) + thresholds[layer_idx].append(threshold) + + for layer_idx, layer_thresholds in thresholds.items(): + thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds) + + return thresholds + + + +def find_thresholds( + model_name, + dataset_name, + dataset_config, + max_samples, + cett_target, + n_quantiles, + save_path, + device, + ): + + # Load tokenizer and model + logger.info(f"Loading model: {model_name}") + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float32, + device_map="auto" if device.type == "cuda" else None, + ) + + if device.type != "cuda": + model = model.to(device) + + model.eval() + model.activation_capture = model.ACTIVATION_CAPTURE(model) + model.activation_capture.register_hooks(hooks=[Hook.UP]) + + # Load dataset + logger.info(f"Loading dataset: {dataset_name}") + if dataset_config: + dataset = load_dataset( + dataset_name, dataset_config, split="train", streaming=True + ) + else: + dataset = load_dataset(dataset_name, split="train", streaming=True) + dataset = dataset.shuffle(buffer_size=10000, seed=42) + + def sample_and_tokenize(examples): + """Sample text chunks before tokenization for efficiency using vectorized operations.""" + texts = examples["text"] + tokenized = tokenizer(texts, return_tensors="pt") + + # Convert to lists + return { + "text": texts, + "input_ids": tokenized["input_ids"], + } + + # Tokenize + dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False) + dataset = dataset.with_format("torch") - for i in range(activations.size(0)): - neuron_outputs = activations[i] * model.model.layers[0].mlp.down_proj.weight - threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) - thresholds.append(threshold) + dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2) # type: ignore + + # Compute thresholds for each layer across all dataset entries + logger.info(f"Beginning to compute thresholds using {max_samples} samples") + thresholds = defaultdict(list) + with torch.no_grad(): + for batch in tqdm.tqdm(dataloader): + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + + _ = model(input_ids=input_ids, attention_mask=attention_mask) + + for layer,layer_idx in enumerate(model.activation_capture.get_layers()): + activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] + activations = activations.view(-1, activations.size(-1)) + + for i in range(activations.size(0)): + neuron_outputs = activations[i] * layer.mlp.down_proj.weight + threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) + thresholds[layer_idx].append(threshold) + + model.activation_capture.clear_captures() + if 
device.type == "cuda": + torch.cuda.empty_cache() + + for layer_idx, layer_thresholds in thresholds.items(): + thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds) + + # Save layerwise thresholds as record in central json file + if not os.path.exists(save_path): + with open("save_path", mode="r", encoding="utf-8") as read_file: + threshold_dict = json.load(read_file) + else: + threshold_dict = {} + threshold_dict[model_name] = thresholds + with open("save_path", mode="r", encoding="utf-8") as write_file: + json.dump(threshold_dict, write_file) + + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Generate training dataset for sparsity predictors" + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)", + ) + parser.add_argument( + "--dataset", + type=str, + default="allenai/c4", + help="Dataset name (default: allenai/c4)", + ) + parser.add_argument( + "--dataset_config", + type=str, + default="en", + help="Dataset configuration (e.g., en for C4)", + ) + parser.add_argument( + "--save_path", + type=str, + default="thresholds.json", + help="Path to json file for thresholds", + ) + parser.add_argument( + "--max_samples", + type=int, + default=500, + help="Maximum number of samples to process", + ) + parser.add_argument( + "--cett_target", + type=float, + default=0.2, + help="Optimal CETT value for threshold-finding", + ) + parser.add_argument( + "--n_quantiles", + type=int, + default=500, + help="Number of quantiles to sort neuron outputs into for threshold-finding", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument( + "--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)" + ) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + # Set seed + set_seed(args.seed) + + # Setup device + if args.device == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(args.device) - return sum(thresholds)/len(thresholds) - \ No newline at end of file + find_thresholds( + model_name=args.model_name, + dataset_name=args.dataset_name, + dataset_config=args.dataset_config, + max_samples=args.max_samples, + cett_target=args.cett_target, + n_quantiles=args.n_quantiles, + save_path=args.save_path, + device=device + ) + diff --git a/src/models/phi3/activation_capture_phi.py b/src/models/phi3/activation_capture_phi.py deleted file mode 100644 index e2e6cf9..0000000 --- a/src/models/phi3/activation_capture_phi.py +++ /dev/null @@ -1,58 +0,0 @@ -from src.activation_capture import ActivationCapture -import torch.nn.functional as F - - - -class ActivationCapturePhi3(ActivationCapture): - """Helper class to capture activations from model layers.""" - has_gate_proj: bool = True - has_up_proj: bool = True - - def get_layers(self, model): - return model.model.layers - - def _register_gate_hook(self, layer_idx, layer): - def hook(module, input, output): - key1 = f"{layer_idx}_{'gate'}" - key2 = f"{layer_idx}_{'up'}" - # Just detach, don't clone or move to CPU yet - gate_outputs, up_outputs = output.chunk(2, dim=1) - self.mlp_activations[key1] = gate_outputs.detach() - self.mlp_activations[key2] = up_outputs.detach() - return output - handle = layer.mlp.gate_up_proj.register_forward_hook(hook) - return handle - - def _register_up_hook(self, layer_idx, layer): - def hook(module, input, output): - key = f"{layer_idx}_{'up'}" - # Just 
detach, don't clone or move to CPU yet - up_outputs = output.chunk(2, dim=1)[1] - self.mlp_activations[key] = up_outputs.detach() - return output - handle = layer.mlp.gate_up_proj.register_forward_hook(hook) - return handle - - def get_gate_activations(self, layer_idx): - """Get combined MLP activations for a layer.""" - gate_key = f"{layer_idx}_gate" - if gate_key in self.mlp_activations: - gate_act = self.mlp_activations[gate_key] - return F.silu(gate_act) - return None - - def get_mlp_activations(self, layer_idx): - """Get combined MLP activations for a layer.""" - gate_key = f"{layer_idx}_gate" - up_key = f"{layer_idx}_up" - - if gate_key in self.mlp_activations and up_key in self.mlp_activations: - # Compute gated activations: gate(x) * up(x) - gate_act = self.mlp_activations[gate_key] - up_act = self.mlp_activations[up_key] - - # Apply SwiGLU activation: silu(gate) * up - gated_act = F.silu(gate_act) * up_act - return gated_act - - return None \ No newline at end of file diff --git a/src/models/phi3/modelling_phi_skip.py b/src/models/phi3/modelling_phi_skip.py new file mode 100644 index 0000000..d4303c4 --- /dev/null +++ b/src/models/phi3/modelling_phi_skip.py @@ -0,0 +1,399 @@ + +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.processing_utils import Unpack +from transformers.utils import logging +from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs + +from transformers.models.phi3.modeling_phi3 import( + Phi3MLP, Phi3Attention, Phi3RMSNorm, Phi3RotaryEmbedding, +) + +from transformers.utils.import_utils import is_torch_flex_attn_available + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from transformers.integrations.flex_attention import make_flex_block_causal_mask + +from src.models.phi3.configuration_phi_skip import Phi3SkipConnectionConfig +from src.modeling_skip import SkipMLP, SkipDecoderLayer, FastLoRAProjection, build_skip_connection_model, build_skip_connection_model_for_causal_lm +from .activation_capture_phi import ActivationCapturePhi3 +logger = logging.get_logger(__name__) + + +class Phi3SkipMLP(SkipMLP): + def __init__(self, hidden_size, intermediate_size, sparsity): + super().__init__(hidden_size, intermediate_size, sparsity, False, "silu") + self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) + + def _fix_unloaded_weights(self): + gate_proj_weight, up_proj_weight = self.gate_up_proj.weight.chunk(2, dim=0) + self.gate_proj.load_state_dict({'weight': gate_proj_weight}, assign=True) + self.up_proj.load_state_dict({'weight': up_proj_weight}, assign=True) + del self.gate_up_proj + return self + + +class Phi3SkipDecoderLayer(SkipDecoderLayer): + def _init_components(self, config, layer_idx): + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def _set_mlp_train(self, config, layer_idx): + self.mlp = Phi3MLP(config) + + def _set_mlp_inference(self, config, layer_idx): + self.mlp = Phi3SkipMLP( + 
config.hidden_size, + config.intermediate_size, + config.sparsity, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if not self.training: # Use PyTorch's built-in training flag + self._compute_binary_mask(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.training and self.is_training_config: + predictor_loss = self.compute_predictor_loss(hidden_states) + else: + predictor_loss = None + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs, predictor_loss + + +class Phi3SkipPreTrainedModel(PreTrainedModel): + config_class = Phi3SkipConnectionConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3SkipDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, 
nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Phi3RMSNorm): + module.weight.data.fill_(1.0) + +Phi3SkipConnectionModelBase = build_skip_connection_model(Phi3SkipPreTrainedModel) + +class Phi3SkipConnectionModel(Phi3SkipConnectionModelBase): + def _init_components(self, config): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # type: ignore + self.layers = nn.ModuleList( + [Phi3SkipDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + + def _initialize_unloaded_weights(self): + for module in self.modules(): + if any(hasattr(p, 'is_meta') and p.is_meta for p in module.parameters()): + if isinstance(module, FastLoRAProjection): + module = module.to_empty(device="cpu") + with torch.no_grad(): + torch.nn.init.xavier_normal_(module.down.weight) + torch.nn.init.zeros_(module.up.weight) # Initialize up projection to zeros for stable training + elif isinstance(module, Phi3SkipMLP): + module._fix_weights() + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Phi3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +Phi3SkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Phi3SkipPreTrainedModel, Phi3SkipConnectionModel) + +class Phi3SkipConnectionForCausalLM(Phi3SkipConnectionForCausalLMBase): + ACTIVATION_CAPTURE = ActivationCapturePhi3 + + _keys_to_ignore_on_load_missing = [ + "model.layers.*.mlp.combined_proj_buffer", + "model.layers.*.mlp.down_proj_buffer", + "model.layers.*.mlp.init_mask", + "model.layers.*.mlp.weight_cache", + "model.layers.*.mlp_lora_proj.down.weight", + "model.layers.*.mlp_lora_proj.intermediate", + "model.layers.*.mlp_lora_proj.output", + "model.layers.*.mlp_lora_proj.up.weight", + "model.layers.*.mlp_mask", + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + "model.layers.*.standard_mlp.gate_proj.weight", + "model.layers.*.standard_mlp.up_proj.weight", + "model.layers.*.standard_mlp.down_proj.weight" + ] + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. 
+ if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + \ No newline at end of file From 22c7f2bc2f553b32681598ec078ffe5c6c3be085 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Mon, 14 Jul 2025 10:24:41 -0400 Subject: [PATCH 3/7] Fixes to activation capture, gemma and CETT Signed-off-by: Kira Selby --- src/activation_capture.py | 6 +- src/cett.py | 66 +++++-------------- .../gemma3n/activation_capture_gemma.py | 11 ---- src/models/gemma3n/modelling_gemma_skip.py | 2 - src/models/phi3/modelling_phi_skip.py | 3 - 5 files changed, 22 insertions(+), 66 deletions(-) delete mode 100644 src/models/gemma3n/activation_capture_gemma.py diff --git a/src/activation_capture.py b/src/activation_capture.py index 3cedf6c..30b8107 100644 --- a/src/activation_capture.py +++ b/src/activation_capture.py @@ -11,7 +11,7 @@ class Hook(Enum): class ActivationCapture(): """Helper class to capture activations from model layers.""" - hooks_available: List[Hook] + hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT] def __init__(self, model): self.model = model @@ -95,4 +95,6 @@ def remove_hooks(self): def clear_captures(self): """Clear captured activations.""" - self.mlp_activations = {} + self.mlp_activations = { + hook: {} for hook in self.hooks_available + } diff --git a/src/cett.py b/src/cett.py index 6a53ed7..0a189b0 100644 --- a/src/cett.py +++ b/src/cett.py @@ -2,7 +2,7 @@ import logging import os import json -import tqdm +from tqdm import tqdm import argparse from datasets import load_dataset @@ -39,46 +39,16 @@ def CETT(threshold): return threshold -def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=500): - model.activation_capture = model.ACTIVATION_CAPTURE(model) - model.activation_capture.register_hooks(hooks=[Hook.UP]) - - thresholds = defaultdict(list) - - with torch.no_grad(): - for batch in dataloader: - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - - model.activation_capture.clear_captures() - - _ = model(input_ids=input_ids, attention_mask=attention_mask) - - for layer,layer_idx in enumerate(model.activation_capture.get_layers()): - activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] - activations = activations.view(-1, activations.size(-1)) - - for i in range(activations.size(0)): - neuron_outputs = activations[i] * layer.mlp.down_proj.weight - threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) - thresholds[layer_idx].append(threshold) - - for layer_idx, layer_thresholds in thresholds.items(): - thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds) - - return thresholds - - - def find_thresholds( - model_name, - dataset_name, - dataset_config, - max_samples, - cett_target, - n_quantiles, - save_path, - device, + model_name: str, + dataset_name: str, + dataset_config: str, + max_samples: int, + cett_target: float, + n_quantiles: int, + save_path: str, + seed: int, + device: torch.device, ): # Load 
tokenizer and model @@ -96,7 +66,7 @@ def find_thresholds( model = model.to(device) model.eval() - model.activation_capture = model.ACTIVATION_CAPTURE(model) + model.activation_capture = ActivationCapture(model) model.activation_capture.register_hooks(hooks=[Hook.UP]) # Load dataset @@ -107,7 +77,7 @@ def find_thresholds( ) else: dataset = load_dataset(dataset_name, split="train", streaming=True) - dataset = dataset.shuffle(buffer_size=10000, seed=42) + dataset = dataset.shuffle(buffer_size=10000, seed=seed) def sample_and_tokenize(examples): """Sample text chunks before tokenization for efficiency using vectorized operations.""" @@ -130,13 +100,12 @@ def sample_and_tokenize(examples): logger.info(f"Beginning to compute thresholds using {max_samples} samples") thresholds = defaultdict(list) with torch.no_grad(): - for batch in tqdm.tqdm(dataloader): + for batch in tqdm(dataloader, total=max_samples): input_ids = batch["input_ids"].to(device) - attention_mask = batch["attention_mask"].to(device) - _ = model(input_ids=input_ids, attention_mask=attention_mask) + _ = model(input_ids.squeeze(0)) - for layer,layer_idx in enumerate(model.activation_capture.get_layers()): + for layer_idx, layer in enumerate(model.activation_capture.get_layers()): activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] activations = activations.view(-1, activations.size(-1)) @@ -232,12 +201,13 @@ def parse_args(): find_thresholds( model_name=args.model_name, - dataset_name=args.dataset_name, + dataset_name=args.dataset, dataset_config=args.dataset_config, max_samples=args.max_samples, cett_target=args.cett_target, n_quantiles=args.n_quantiles, save_path=args.save_path, - device=device + seed=args.seed, + device=device, ) diff --git a/src/models/gemma3n/activation_capture_gemma.py b/src/models/gemma3n/activation_capture_gemma.py deleted file mode 100644 index 8364d6f..0000000 --- a/src/models/gemma3n/activation_capture_gemma.py +++ /dev/null @@ -1,11 +0,0 @@ -from src.activation_capture import ActivationCaptureDefault - - -class ActivationCaptureGemma3n(ActivationCaptureDefault): - """Helper class to capture activations from model layers.""" - - def _register_gate_hook(self, layer_idx, layer): - handle = layer.mlp.act_fn.register_forward_hook( - self._create_mlp_hook(layer_idx, 'gate') - ) - return handle diff --git a/src/models/gemma3n/modelling_gemma_skip.py b/src/models/gemma3n/modelling_gemma_skip.py index 7fcdd8a..cf0480f 100644 --- a/src/models/gemma3n/modelling_gemma_skip.py +++ b/src/models/gemma3n/modelling_gemma_skip.py @@ -27,7 +27,6 @@ from sparse_transformers import sparse_mlp_forward from src.models.gemma3n.configuration_gemma_skip import Gemma3nSkipConnectionConfig -from src.models.gemma3n.activation_capture_gemma import ActivationCaptureGemma3n from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm logger = logging.get_logger(__name__) @@ -413,7 +412,6 @@ def project_per_layer_inputs( Gemma3nSkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Gemma3nSkipPreTrainedModel, Gemma3nSkipConnectionModel) class Gemma3nSkipConnectionForCausalLM(Gemma3nSkipConnectionForCausalLMBase): - ACTIVATION_CAPTURE = ActivationCaptureGemma3n _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/models/phi3/modelling_phi_skip.py b/src/models/phi3/modelling_phi_skip.py index d4303c4..f3851b7 100644 --- 
a/src/models/phi3/modelling_phi_skip.py +++ b/src/models/phi3/modelling_phi_skip.py @@ -25,7 +25,6 @@ from src.models.phi3.configuration_phi_skip import Phi3SkipConnectionConfig from src.modeling_skip import SkipMLP, SkipDecoderLayer, FastLoRAProjection, build_skip_connection_model, build_skip_connection_model_for_causal_lm -from .activation_capture_phi import ActivationCapturePhi3 logger = logging.get_logger(__name__) @@ -339,8 +338,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Phi3SkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Phi3SkipPreTrainedModel, Phi3SkipConnectionModel) class Phi3SkipConnectionForCausalLM(Phi3SkipConnectionForCausalLMBase): - ACTIVATION_CAPTURE = ActivationCapturePhi3 - _keys_to_ignore_on_load_missing = [ "model.layers.*.mlp.combined_proj_buffer", "model.layers.*.mlp.down_proj_buffer", From 37a238eeaa828d7da8fde4e742b968da4e5300e5 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Tue, 15 Jul 2025 17:14:48 -0400 Subject: [PATCH 4/7] updating cett code Signed-off-by: Kira Selby --- src/cett.py | 195 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 167 insertions(+), 28 deletions(-) diff --git a/src/cett.py b/src/cett.py index 0a189b0..617668b 100644 --- a/src/cett.py +++ b/src/cett.py @@ -7,6 +7,7 @@ from datasets import load_dataset import torch +from torch.utils.data import DataLoader as TorchDataLoader from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM from transformers.trainer_utils import set_seed @@ -16,39 +17,173 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000): - norms = neuron_outputs.norm(dim=0) - quantiles = norms.quantile(torch.linspace(0,1,n_quantiles)) - tot_norm = neuron_outputs.sum(dim=1).norm() - def CETT(threshold): - threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=1).norm() - return threshold_norm / tot_norm +import copy +class ThresholdEvaluator(): + def __init__(self, model, thresholds): + self.model = model + self.thresholds = thresholds + self.compute_neuron_thresholds(thresholds) + + self.mlp_outputs = defaultdict(list) + self.handles = [] + + def get_layers(self): + return self.model.model.layers + + def compute_neuron_thresholds(self, thresholds): + n_layers = len(self.get_layers()) + self.neuron_thresholds = torch.zeros(n_layers, self.model.config.intermediate_size) + with torch.no_grad(): + for layer_idx, layer in self.get_layers(): + norms = layer.mlp.down_proj.weight.norm(dim=0) + self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms + + def _inspect_hook(self, layer_idx): + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + out = output.view(-1, output.size(-1)).clone().detach() + self.mlp_outputs[layer_idx].append(out) + return output + return hook + + def _threshold_hook(self, layer_idx): + def hook(module, input, output): + # Just detach, don't clone or move to CPU yet + mask = (output > self.neuron_thresholds[layer_idx]).bool() + return output * mask + return hook + + def apply_thresholds(self): + for layer_idx, layer in enumerate(self.get_layers()): + handle = layer.mlp.act_fn.register_forward_hook( + self._threshold_hook(layer_idx) + ) + self.handles.append(handle) + + def apply_hooks(self): + for layer_idx, layer in enumerate(self.get_layers()): + handle = layer.mlp.register_forward_hook( + self._inspect_hook(layer_idx) + ) + self.handles.append(handle) + + def 
clear_captures(self): + self.mlp_outputs = defaultdict(list) + + def remove_hooks(self): + for handle in self.handles: + handle.remove() + self.handles = [] + + def evaluate(self, inputs): + self.apply_hooks() + + with torch.no_grad(): + for inp in inputs: + _ = self.model(**inp) + + ground_truth_outputs = { + idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs + } + self.clear_captures() + + self.apply_thresholds() + with torch.no_grad(): + for inp in inputs: + _ = self.model(**inp) + + threshold_outputs = { + idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs + } + self.clear_captures() + + + +# +# TODO: +# 1. Test out precomputing down_proj norms and see if that improves performance +# 2. Ensure that the thresholds lead to reasonable results for downstream evaluation +# +# + + + + +def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None): + if not norms: # pass both or neither + norms = norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) + tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) + threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1) + return threshold_norm / tot_norm + +''' +def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000): + neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) + norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) + min_value = norms.min() + max_value = norms.quantile(0.99) + threshold_grid = torch.linspace(min_value, max_value, n_thresholds) + tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) + thresholds = torch.zeros(neuron_outputs.size(0)) + + initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) + thresholds[initial_cett < cett_target] = max_value + + for j in tqdm(range(neuron_outputs.size(0))): + if thresholds[j] == 0: + left = 0 + right = n_thresholds + while left < right: + mid = (left + right) // 2 + cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j]) + if cett <= cett_target: + left = mid + 1 + else: + right = mid + thresholds[j] = threshold_grid[left] + return thresholds +''' + +def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000): + neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) + norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) + tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) + + min_value = norms.min() + max_value = norms.quantile(0.99) + threshold_grid = torch.linspace(min_value, max_value, n_thresholds) + #initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) + #outlier_mask = initial_cett > cett_target + left = 0 - right = quantiles.size(0) - threshold = 0 + right = n_thresholds while left < right: + print(left,right) mid = (left + right) // 2 - cett = CETT(quantiles[mid]) + #cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm)[outlier_mask].mean() + cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm).mean() if cett <= cett_target: left = mid + 1 - threshold = quantiles[mid] else: - right = mid - 1 - return threshold + right = mid + return threshold_grid[left] def find_thresholds( model_name: str, dataset_name: str, dataset_config: str, - max_samples: int, - cett_target: float, - n_quantiles: int, save_path: str, - seed: int, - device: torch.device, + batch_size: int = 8, + max_samples: int = 128, + max_length: int = 256, + cett_target: 
float = 0.2, + n_thresholds: int = 10000, + num_workers: int = 8, + seed: int = 42, + device: torch.device = torch.device("cpu"), ): # Load tokenizer and model @@ -82,19 +217,24 @@ def find_thresholds( def sample_and_tokenize(examples): """Sample text chunks before tokenization for efficiency using vectorized operations.""" texts = examples["text"] - tokenized = tokenizer(texts, return_tensors="pt") + tokenized = tokenizer( + texts, + max_length=max_length, + truncation=True, + return_tensors="pt" + ) # Convert to lists return { - "text": texts, "input_ids": tokenized["input_ids"], + "attention_mask": tokenized["attention_mask"] } # Tokenize dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False) dataset = dataset.with_format("torch") - dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2) # type: ignore + dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=num_workers, pin_memory=False, prefetch_factor=2) # type: ignore # Compute thresholds for each layer across all dataset entries logger.info(f"Beginning to compute thresholds using {max_samples} samples") @@ -102,17 +242,16 @@ def sample_and_tokenize(examples): with torch.no_grad(): for batch in tqdm(dataloader, total=max_samples): input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) - _ = model(input_ids.squeeze(0)) + _ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0)) for layer_idx, layer in enumerate(model.activation_capture.get_layers()): + down_weight = layer.mlp.down_proj.weight activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] - activations = activations.view(-1, activations.size(-1)) - - for i in range(activations.size(0)): - neuron_outputs = activations[i] * layer.mlp.down_proj.weight - threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) - thresholds[layer_idx].append(threshold) + neuron_outputs = activations.unsqueeze(-2) * down_weight + threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds) + thresholds[layer_idx].append(threshold) model.activation_capture.clear_captures() if device.type == "cuda": From 5de9727f05f36b32b98179d2ccf171751747e413 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Fri, 18 Jul 2025 10:09:33 -0400 Subject: [PATCH 5/7] Filter outliers Signed-off-by: Kira Selby --- src/cett.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cett.py b/src/cett.py index 617668b..8c23a1f 100644 --- a/src/cett.py +++ b/src/cett.py @@ -154,16 +154,16 @@ def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000): min_value = norms.min() max_value = norms.quantile(0.99) threshold_grid = torch.linspace(min_value, max_value, n_thresholds) - #initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) - #outlier_mask = initial_cett > cett_target + max_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) + outlier_mask = max_cett > cett_target left = 0 right = n_thresholds while left < right: print(left,right) mid = (left + right) // 2 - #cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm)[outlier_mask].mean() - cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm).mean() + cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT 
for each token + cett = cett[outlier_mask].mean() # Remove outliers and take average if cett <= cett_target: left = mid + 1 else: From 76059002a82dfe8bb2274c66c98d792cf81b28c9 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Fri, 18 Jul 2025 10:53:19 -0400 Subject: [PATCH 6/7] Precompute column norms to significantly speed up computation Signed-off-by: Kira Selby --- src/cett.py | 157 +++++++--------------------------------------------- 1 file changed, 21 insertions(+), 136 deletions(-) diff --git a/src/cett.py b/src/cett.py index 8c23a1f..77a86e6 100644 --- a/src/cett.py +++ b/src/cett.py @@ -18,151 +18,34 @@ logger = logging.getLogger(__name__) -import copy -class ThresholdEvaluator(): - def __init__(self, model, thresholds): - self.model = model - self.thresholds = thresholds - - self.compute_neuron_thresholds(thresholds) - - self.mlp_outputs = defaultdict(list) - self.handles = [] - - def get_layers(self): - return self.model.model.layers - - def compute_neuron_thresholds(self, thresholds): - n_layers = len(self.get_layers()) - self.neuron_thresholds = torch.zeros(n_layers, self.model.config.intermediate_size) - with torch.no_grad(): - for layer_idx, layer in self.get_layers(): - norms = layer.mlp.down_proj.weight.norm(dim=0) - self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms - - def _inspect_hook(self, layer_idx): - def hook(module, input, output): - # Just detach, don't clone or move to CPU yet - out = output.view(-1, output.size(-1)).clone().detach() - self.mlp_outputs[layer_idx].append(out) - return output - return hook - - def _threshold_hook(self, layer_idx): - def hook(module, input, output): - # Just detach, don't clone or move to CPU yet - mask = (output > self.neuron_thresholds[layer_idx]).bool() - return output * mask - return hook - - def apply_thresholds(self): - for layer_idx, layer in enumerate(self.get_layers()): - handle = layer.mlp.act_fn.register_forward_hook( - self._threshold_hook(layer_idx) - ) - self.handles.append(handle) - - def apply_hooks(self): - for layer_idx, layer in enumerate(self.get_layers()): - handle = layer.mlp.register_forward_hook( - self._inspect_hook(layer_idx) - ) - self.handles.append(handle) - - def clear_captures(self): - self.mlp_outputs = defaultdict(list) - - def remove_hooks(self): - for handle in self.handles: - handle.remove() - self.handles = [] - - def evaluate(self, inputs): - self.apply_hooks() - - with torch.no_grad(): - for inp in inputs: - _ = self.model(**inp) - - ground_truth_outputs = { - idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs - } - self.clear_captures() - - self.apply_thresholds() - with torch.no_grad(): - for inp in inputs: - _ = self.model(**inp) - - threshold_outputs = { - idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs - } - self.clear_captures() - - - -# -# TODO: -# 1. Test out precomputing down_proj norms and see if that improves performance -# 2. 
Ensure that the thresholds lead to reasonable results for downstream evaluation -# -# - +def cett_from_threshold(activations, down_weight, threshold, norms=None, tot_norm=None): + if norms is None: + col_norms = down_weight.norm(dim=0) + norms = activations.abs() * col_norms + tot_norm = activations.matmul(down_weight.t()).norm(dim=-1) + masked_act = activations * (norms < threshold) + threshold_norm = masked_act.matmul(down_weight.t()).norm(dim=-1) + return threshold_norm / tot_norm -def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None): - if not norms: # pass both or neither - norms = norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) - threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1) - return threshold_norm / tot_norm +def calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds=1000): + norms = activations.abs() * col_norms + output = activations.matmul(down_weight.t()) + tot_norm = output.norm(dim=-1) -''' -def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000): - neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) - norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) min_value = norms.min() max_value = norms.quantile(0.99) threshold_grid = torch.linspace(min_value, max_value, n_thresholds) - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) - thresholds = torch.zeros(neuron_outputs.size(0)) - - initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) - thresholds[initial_cett < cett_target] = max_value - - for j in tqdm(range(neuron_outputs.size(0))): - if thresholds[j] == 0: - left = 0 - right = n_thresholds - while left < right: - mid = (left + right) // 2 - cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j]) - if cett <= cett_target: - left = mid + 1 - else: - right = mid - thresholds[j] = threshold_grid[left] - return thresholds -''' - -def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000): - neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) - norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) - - min_value = norms.min() - max_value = norms.quantile(0.99) - threshold_grid = torch.linspace(min_value, max_value, n_thresholds) - max_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) + max_cett = cett_from_threshold(activations, down_weight, max_value, norms=norms, tot_norm=tot_norm) outlier_mask = max_cett > cett_target - + left = 0 right = n_thresholds while left < right: - print(left,right) + #print(left,right) mid = (left + right) // 2 - cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token + cett = cett_from_threshold(activations, down_weight, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token cett = cett[outlier_mask].mean() # Remove outliers and take average if cett <= cett_target: left = mid + 1 @@ -240,17 +123,19 @@ def sample_and_tokenize(examples): logger.info(f"Beginning to compute thresholds using {max_samples} samples") thresholds = defaultdict(list) with torch.no_grad(): + all_col_norms = {layer_idx: layer.mlp.down_proj.weight.norm(dim=0) \ + for layer_idx, layer in enumerate(model.activation_capture.get_layers())} for batch in tqdm(dataloader, total=max_samples): input_ids = 
batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) - _ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0)) + _ = model(input_ids=input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1)) for layer_idx, layer in enumerate(model.activation_capture.get_layers()): down_weight = layer.mlp.down_proj.weight + col_norms = all_col_norms[layer_idx] activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] - neuron_outputs = activations.unsqueeze(-2) * down_weight - threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds) + threshold = calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds) thresholds[layer_idx].append(threshold) model.activation_capture.clear_captures() From 0f17b8e3fde8a9d946f743d2df0c19b3d0147075 Mon Sep 17 00:00:00 2001 From: Kira Selby Date: Fri, 18 Jul 2025 12:58:27 -0400 Subject: [PATCH 7/7] Working implementation Signed-off-by: Kira Selby --- src/cett.py | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/src/cett.py b/src/cett.py index 77a86e6..c4df938 100644 --- a/src/cett.py +++ b/src/cett.py @@ -18,6 +18,39 @@ logger = logging.getLogger(__name__) +# from https://github.com/pytorch/pytorch/issues/64947#issuecomment-2810054982 +def quantile(tensor, q, dim=None, keepdim=False): + """ + Computes the quantile of the input tensor along the specified dimension. + + Parameters: + tensor (torch.Tensor): The input tensor. + q (float): The quantile to compute, should be a float between 0 and 1. + dim (int): The dimension to reduce. If None, the tensor is flattened. + keepdim (bool): Whether to keep the reduced dimension in the output. + Returns: + torch.Tensor: The quantile value(s) along the specified dimension. 
+ """ + assert 0 <= q <= 1, "\n\nquantile value should be a float between 0 and 1.\n\n" + + if dim is None: + tensor = tensor.flatten() + dim = 0 + + sorted_tensor, _ = torch.sort(tensor, dim=dim) + num_elements = sorted_tensor.size(dim) + index = q * (num_elements - 1) + lower_index = int(index) + upper_index = min(lower_index + 1, num_elements - 1) + lower_value = sorted_tensor.select(dim, lower_index) + upper_value = sorted_tensor.select(dim, upper_index) + # linear interpolation + weight = index - lower_index + quantile_value = (1 - weight) * lower_value + weight * upper_value + + return quantile_value.unsqueeze(dim) if keepdim else quantile_value + + def cett_from_threshold(activations, down_weight, threshold, norms=None, tot_norm=None): if norms is None: @@ -35,7 +68,7 @@ def calculate_threshold(activations, down_weight, col_norms, cett_target, n_thre tot_norm = output.norm(dim=-1) min_value = norms.min() - max_value = norms.quantile(0.99) + max_value = quantile(norms, 0.99) threshold_grid = torch.linspace(min_value, max_value, n_thresholds) max_cett = cett_from_threshold(activations, down_weight, max_value, norms=norms, tot_norm=tot_norm) outlier_mask = max_cett > cett_target @@ -117,7 +150,7 @@ def sample_and_tokenize(examples): dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False) dataset = dataset.with_format("torch") - dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=num_workers, pin_memory=False, prefetch_factor=2) # type: ignore + dataloader = TorchDataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, prefetch_factor=2) # type: ignore # Compute thresholds for each layer across all dataset entries logger.info(f"Beginning to compute thresholds using {max_samples} samples") @@ -125,7 +158,7 @@ def sample_and_tokenize(examples): with torch.no_grad(): all_col_norms = {layer_idx: layer.mlp.down_proj.weight.norm(dim=0) \ for layer_idx, layer in enumerate(model.activation_capture.get_layers())} - for batch in tqdm(dataloader, total=max_samples): + for batch in tqdm(dataloader, total=max_samples//batch_size): input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) @@ -146,7 +179,7 @@ def sample_and_tokenize(examples): thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds) # Save layerwise thresholds as record in central json file - if not os.path.exists(save_path): + if os.path.exists(save_path): with open("save_path", mode="r", encoding="utf-8") as read_file: threshold_dict = json.load(read_file) else:
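
The patches above converge on a single recipe for picking per-layer sparsity thresholds: score each neuron by |activation| times the column norm of down_proj, define CETT at a given threshold as the norm of the MLP output rebuilt from only the sub-threshold neurons divided by the norm of the full output, and binary-search a threshold grid until the mean CETT over non-outlier tokens reaches the target. The standalone sketch below restates that recipe with toy tensors; the function names, shapes, and inputs are illustrative assumptions, not the repository's API.

import torch

def cett(acts, down_w, threshold, contrib_norms, full_norm):
    # CETT: norm of the MLP output using only neurons whose contribution norm
    # is below the threshold, relative to the norm of the full MLP output.
    masked = acts * (contrib_norms < threshold)
    return masked.matmul(down_w.t()).norm(dim=-1) / full_norm

def search_threshold(acts, down_w, cett_target=0.2, n_thresholds=1000):
    contrib_norms = acts.abs() * down_w.norm(dim=0)     # |a_i| * ||down_proj[:, i]||
    full_norm = acts.matmul(down_w.t()).norm(dim=-1)    # per-token norm of the full MLP output
    grid = torch.linspace(contrib_norms.min().item(),
                          torch.quantile(contrib_norms, 0.99).item(),
                          n_thresholds)
    # Tokens whose CETT stays below the target even at the largest grid value would
    # drag the search upward, so they are excluded from the mean (outlier filtering,
    # as in the patch above).
    keep = cett(acts, down_w, grid[-1], contrib_norms, full_norm) > cett_target
    lo, hi = 0, n_thresholds
    while lo < hi:  # largest grid threshold whose mean CETT stays <= the target
        mid = (lo + hi) // 2
        if cett(acts, down_w, grid[mid], contrib_norms, full_norm)[keep].mean() <= cett_target:
            lo = mid + 1
        else:
            hi = mid
    return grid[min(lo, n_thresholds - 1)]

# Toy usage: 16 tokens, hidden size 32, intermediate size 128 (arbitrary sizes).
acts = torch.randn(16, 128)
down_w = torch.randn(32, 128)
print(search_threshold(acts, down_w))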