diff --git a/generate_dataset.py b/generate_dataset.py
index 1c3e6d8..6dc3807 100644
--- a/generate_dataset.py
+++ b/generate_dataset.py
@@ -43,7 +43,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.trainer_utils import set_seed
 
-from src.activation_capture import ActivationCaptureTraining
+from src.activation_capture import Hook
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -120,14 +120,14 @@ def process_batch(
     hidden_states_dict = {}
     mlp_activations_dict = {}
     for layer_idx in range(num_layers):
-        hidden_state = model.activation_capture.get_hidden_states(layer_idx)[0]
+        hidden_state = model.activation_capture.mlp_activations[Hook.IN][layer_idx][0]
         hidden_states_dict[layer_idx] = (
            hidden_state.view(-1, hidden_state.shape[-1])
            .cpu()
            .numpy()
            .astype(np.float32)
         )
-        mlp_activation = model.activation_capture.get_gate_activations(layer_idx)
+        mlp_activation = model.activation_capture.mlp_activations[Hook.ACT][layer_idx]
         mlp_activations_dict[layer_idx] = (
            mlp_activation[0]
            .view(-1, mlp_activation.shape[-1])
@@ -172,8 +172,8 @@ def generate_dataset(
     model = model.to(device)
     model.eval()
 
-    model.activation_capture = ActivationCaptureTraining(model)
-    model.activation_capture.register_hooks()
+    model.activation_capture = model.ACTIVATION_CAPTURE(model)
+    model.activation_capture.register_hooks(hooks=[Hook.IN, Hook.ACT])
 
     # Get model dimensions
     hidden_dim = model.config.hidden_size
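
# Usage sketch (illustrative only, not part of the patch). In LLaMA-style decoder
# layers, Hook.IN captures the tensor fed to layer.mlp (after attention and the
# post-attention layernorm), whereas the removed ActivationCaptureTraining hooked
# the decoder layer itself and returned the layer input; the flattening into
# float32 rows is unchanged. flatten_layer_capture is a hypothetical helper and
# assumes a capture already populated by a forward pass with Hook.IN and Hook.ACT
# registered, as in generate_dataset.py above.
import numpy as np

from src.activation_capture import Hook


def flatten_layer_capture(capture, layer_idx):
    # First batch element only, mirroring process_batch above.
    hidden = capture.mlp_activations[Hook.IN][layer_idx][0]
    acts = capture.mlp_activations[Hook.ACT][layer_idx][0]
    rows_hidden = hidden.view(-1, hidden.shape[-1]).cpu().numpy().astype(np.float32)
    rows_acts = acts.view(-1, acts.shape[-1]).cpu().numpy().astype(np.float32)
    return rows_hidden, rows_acts
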
diff --git a/measure_contextual_sparsity.py b/measure_contextual_sparsity.py
index 8e70939..21f8036 100644
--- a/measure_contextual_sparsity.py
+++ b/measure_contextual_sparsity.py
@@ -13,7 +13,7 @@
 from transformers.trainer_utils import set_seed
 import matplotlib.pyplot as plt
 
-from src.activation_capture import ActivationCaptureDefault
+from src.activation_capture import Hook
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -28,16 +28,14 @@ def __init__(self, model, tokenizer, device):
         self.tokenizer = tokenizer
         self.device = device
 
-        model.activation_capture = ActivationCaptureDefault(model)
-        model.activation_capture.register_hooks()
+        model.activation_capture = model.ACTIVATION_CAPTURE(model)
+        model.activation_capture.register_hooks(hooks=[Hook.ACT])
 
         self.num_layers = len(self.model.activation_capture.get_layers())
         self.reset_buffers()
 
     def reset_buffers(self):
-        self.mlp_sparsity = {}
-        self.mlp_sparsity["gate"] = defaultdict(list)
-        self.mlp_sparsity["up"] = defaultdict(list)
+        self.mlp_sparsity = defaultdict(list)
         self.num_seqs = 0
 
     def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
@@ -54,26 +52,19 @@ def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
 
         # Compute sparsity
         for layer_idx in range(self.num_layers):
-            sparsity_masks_gate = (
-                self.model.activation_capture.get_gate_activations(layer_idx) <= 0
-            )
-            sparsity_masks_up = (
-                self.model.activation_capture.get_up_activations(layer_idx) <= 0
+            sparsity_masks = (
+                self.model.activation_capture.mlp_activations[Hook.ACT][layer_idx] <= 0
             )
 
             # Naive sparsity computation
-            self.mlp_sparsity["gate"][layer_idx].append(
-                sparsity_masks_gate.float().mean().item()
-            )
-            self.mlp_sparsity["up"][layer_idx].append(
-                sparsity_masks_up.float().mean().item()
+            self.mlp_sparsity[layer_idx].append(
+                sparsity_masks.float().mean().item()
             )
 
             # Level of sparsity after union over batch dim
             # union_sparsity_mask = sparsity_masks.any(dim=0)
             # self.union_sparsity[batch_size][layer_idx].append(union_sparsity_mask.float().mean().item())
-            # TODO: Add HNSW sparsity computation for both attn heads and mlp neurons
 
             # TODO: Compute union sparsity over multiple different batch sizes
 
         # Clear GPU tensors from capture to free memory
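
# Usage sketch (illustrative only, not part of the patch): the naive per-layer
# sparsity metric computed above, applied to one captured tensor. Because
# SiLU(x) <= 0 exactly when x <= 0, thresholding the act_fn *input* (Hook.ACT)
# at zero reproduces the mask the old gate-output-based code produced. The random
# tensor and LLaMA-7B-sized intermediate dimension are stand-ins for a real capture.
import torch

gate_pre_act = torch.randn(2, 16, 11008)        # [batch, seq, intermediate_dim]
sparsity_mask = gate_pre_act <= 0               # True where a neuron is inactive for a token
layer_sparsity = sparsity_mask.float().mean().item()
print(f"fraction of inactive neurons: {layer_sparsity:.3f}")  # ~0.5 for random inputs
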
diff --git a/src/activation_capture.py b/src/activation_capture.py
index bbc3ca3..30b8107 100644
--- a/src/activation_capture.py
+++ b/src/activation_capture.py
@@ -1,54 +1,91 @@
-from typing_extensions import override
-import torch.nn.functional as F
-from abc import ABC, abstractmethod
+from enum import Enum
+from typing import List
 
 
-class ActivationCapture(ABC):
+class Hook(Enum):
+    IN = "IN"
+    ACT = "ACT"
+    UP = "UP"
+    OUT = "OUT"
+
+
+class ActivationCapture():
     """Helper class to capture activations from model layers."""
-    has_gate_proj: bool
-    has_up_proj: bool
+    hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT]
 
     def __init__(self, model):
         self.model = model
-        self.mlp_activations = {}
+        self.mlp_activations = {
+            hook: {} for hook in self.hooks_available
+        }
         self.handles = []
 
-    @abstractmethod
-    def _register_gate_hook(self, layer_idx, layer):
-        pass
+    def _register_in_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Detach a clone; keep it on the original device for now
+            self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.register_forward_hook(hook)
+        return handle
+
+    def _register_act_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Detach a clone; keep it on the original device for now
+            self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.act_fn.register_forward_hook(hook)
+        return handle
 
-    @abstractmethod
     def _register_up_hook(self, layer_idx, layer):
-        pass
+        def hook(module, input, output):
+            # Detach a clone; keep it on the original device for now
+            self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.down_proj.register_forward_hook(hook)
+        return handle
+
+    def _register_out_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Detach a clone; keep it on the original device for now
+            self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach()
+            return output
+        handle = layer.mlp.register_forward_hook(hook)
+        return handle
 
-    @abstractmethod
     def get_layers(self):
-        pass
-
-
-    @abstractmethod
-    def get_gate_activations(self, layer_idx):
-        """Get combined MLP activations for a layer."""
-        pass
+        return self.model.get_decoder().layers
 
-    def register_hooks(self):
+    def register_hooks(self, hooks=(Hook.ACT, Hook.UP, Hook.OUT)):
         """Register forward hooks to capture activations."""
         # Clear any existing hooks
         self.remove_hooks()
 
         # Hook into each transformer layer
-        for i, layer in enumerate(self.get_layers()):
-            # Capture MLP gate activations (after activation function)
-            if self.has_gate_proj:
-                handle = self._register_gate_hook(i, layer)
+        for i, layer in enumerate(self.get_layers()):
+            # Hooks capturing inputs to the MLP layer
+            if Hook.IN in hooks and Hook.IN in self.hooks_available:
+                handle = self._register_in_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
+
+            # Hooks capturing inputs to the activation function
+            if Hook.ACT in hooks and Hook.ACT in self.hooks_available:
+                handle = self._register_act_hook(i, layer)
                 if handle is not None:
                     self.handles.append(handle)
-
-            # Also capture up_proj activations
-            if self.has_up_proj:
+
+            # Hooks capturing inputs to the down projection
+            if Hook.UP in hooks and Hook.UP in self.hooks_available:
                 handle = self._register_up_hook(i, layer)
                 if handle is not None:
                     self.handles.append(handle)
+
+            # Hooks capturing the final MLP output
+            if Hook.OUT in hooks and Hook.OUT in self.hooks_available:
+                handle = self._register_out_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
+
 
     def remove_hooks(self):
         """Remove all registered hooks."""
@@ -58,92 +95,6 @@ def remove_hooks(self):
 
     def clear_captures(self):
         """Clear captured activations."""
-        self.mlp_activations = {}
-
-
-
-class ActivationCaptureDefault(ActivationCapture):
-    """Helper class to capture activations from model layers."""
-    has_gate_proj: bool = True
-    has_up_proj: bool = True
-
-    def get_layers(self):
-        return self.model.get_decoder().layers
-
-    def _create_mlp_hook(self, layer_idx, proj_type):
-        def hook(module, input, output):
-            key = f"{layer_idx}_{proj_type}"
-            # Just detach, don't clone or move to CPU yet
-            self.mlp_activations[key] = output.clone().detach()
-            return output
-        return hook
-
-    def _register_gate_hook(self, layer_idx, layer):
-        handle = layer.mlp.gate_proj.register_forward_hook(
-            self._create_mlp_hook(layer_idx, 'gate')
-        )
-        return handle
-
-    def _register_up_hook(self, layer_idx, layer):
-        handle = layer.mlp.up_proj.register_forward_hook(
-            self._create_mlp_hook(layer_idx, 'up')
-        )
-        return handle
-
-    def get_gate_activations(self, layer_idx):
-        gate_key = f"{layer_idx}_gate"
-        if gate_key in self.mlp_activations:
-            gate_act = self.mlp_activations[gate_key]
-            return F.silu(gate_act)
-        return None
-
-    def get_up_activations(self, layer_idx):
-        up_key = f"{layer_idx}_up"
-        if up_key in self.mlp_activations:
-            up_act = self.mlp_activations[up_key]
-            return up_act
-        return None
-
-class ActivationCaptureTraining(ActivationCaptureDefault):
-    """Additional Hidden State capture for training dataset generation"""
-    def __init__(self, model):
-        super().__init__(model)
-        self.hidden_states = {}
-
-    def _create_hidden_state_hook(self, layer_idx, layer):
-        def hook(module, args, kwargs, output):
-            # args[0] is the input hidden states to the layer
-            if len(args) > 0:
-                # Just detach, don't clone or move to CPU yet
-                self.hidden_states[layer_idx] = args[0].clone().detach()
-            return output
-        return hook
-
-    def _register_hidden_state_hook(self, layer_idx, layer):
-        handle = layer.register_forward_hook(
-            self._create_hidden_state_hook(layer_idx, layer),
-            with_kwargs=True
-        )
-        return handle
-
-    @override
-    def clear_captures(self):
-        """Clear captured activations."""
-        super().clear_captures()
-        self.hidden_states = {}
-
-    @override
-    def register_hooks(self):
-        """Register forward hooks to capture activations."""
-        # Clear any existing hooks
-        super().register_hooks()
-        # Hook into each transformer layer
-        for i, layer in enumerate(self.get_layers()):
-            # Capture hidden states before MLP
-            handle = self._register_hidden_state_hook(i, layer)
-            if handle is not None:
-                self.handles.append(handle)
-
-    def get_hidden_states(self, layer_idx):
-        """Get hidden states for a layer."""
-        return self.hidden_states[layer_idx]
+        self.mlp_activations = {
+            hook: {} for hook in self.hooks_available
+        }
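
# Extension sketch (illustrative only, not part of the patch). The base class now
# assumes every decoder layer exposes mlp.act_fn and mlp.down_proj; the per-model
# subclasses for Gemma3n and Phi3 are deleted later in this diff. The subclass
# below is hypothetical and shows how an architecture with, say, a fused gate/up
# projection and no standalone act_fn module could restrict hooks_available instead.
from src.activation_capture import ActivationCapture, Hook


class ActivationCaptureFusedMLP(ActivationCapture):
    # Only the generic MLP input/output probes are safe for this architecture.
    hooks_available = [Hook.IN, Hook.OUT]


def build_capture(model):
    capture = ActivationCaptureFusedMLP(model)
    # Hook.ACT is requested but silently skipped: it is not in hooks_available.
    capture.register_hooks(hooks=[Hook.IN, Hook.ACT])
    return capture
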
diff --git a/src/cett.py b/src/cett.py
new file mode 100644
index 0000000..c4df938
--- /dev/null
+++ b/src/cett.py
@@ -0,0 +1,271 @@
+from collections import defaultdict
+import logging
+import os
+import json
+from tqdm import tqdm
+import argparse
+
+from datasets import load_dataset
+import torch
+from torch.utils.data import DataLoader as TorchDataLoader
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+from transformers.trainer_utils import set_seed
+
+from src.activation_capture import ActivationCapture, Hook
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# from https://github.com/pytorch/pytorch/issues/64947#issuecomment-2810054982
+def quantile(tensor, q, dim=None, keepdim=False):
+    """
+    Computes the quantile of the input tensor along the specified dimension.
+
+    Parameters:
+        tensor (torch.Tensor): The input tensor.
+        q (float): The quantile to compute, should be a float between 0 and 1.
+        dim (int): The dimension to reduce. If None, the tensor is flattened.
+        keepdim (bool): Whether to keep the reduced dimension in the output.
+    Returns:
+        torch.Tensor: The quantile value(s) along the specified dimension.
+    """
+    assert 0 <= q <= 1, "\n\nquantile value should be a float between 0 and 1.\n\n"
+
+    if dim is None:
+        tensor = tensor.flatten()
+        dim = 0
+
+    sorted_tensor, _ = torch.sort(tensor, dim=dim)
+    num_elements = sorted_tensor.size(dim)
+    index = q * (num_elements - 1)
+    lower_index = int(index)
+    upper_index = min(lower_index + 1, num_elements - 1)
+    lower_value = sorted_tensor.select(dim, lower_index)
+    upper_value = sorted_tensor.select(dim, upper_index)
+    # linear interpolation
+    weight = index - lower_index
+    quantile_value = (1 - weight) * lower_value + weight * upper_value
+
+    return quantile_value.unsqueeze(dim) if keepdim else quantile_value
+
+
+
+def cett_from_threshold(activations, down_weight, threshold, norms=None, tot_norm=None):
+    if norms is None:
+        col_norms = down_weight.norm(dim=0)
+        norms = activations.abs() * col_norms
+        tot_norm = activations.matmul(down_weight.t()).norm(dim=-1)
+    masked_act = activations * (norms < threshold)
+    threshold_norm = masked_act.matmul(down_weight.t()).norm(dim=-1)
+    return threshold_norm / tot_norm
+
+
+def calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds=1000):
+    norms = activations.abs() * col_norms
+    output = activations.matmul(down_weight.t())
+    tot_norm = output.norm(dim=-1)
+
+    min_value = norms.min()
+    max_value = quantile(norms, 0.99)
+    threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
+    max_cett = cett_from_threshold(activations, down_weight, max_value, norms=norms, tot_norm=tot_norm)
+    outlier_mask = max_cett > cett_target
+
+    left = 0
+    right = n_thresholds
+    while left < right:
+        #print(left,right)
+        mid = (left + right) // 2
+        cett = cett_from_threshold(activations, down_weight, threshold_grid[mid], norms=norms, tot_norm=tot_norm)  # Compute CETT for each token
+        cett = cett[outlier_mask].mean()  # Remove outliers and take average
+        if cett <= cett_target:
+            left = mid + 1
+        else:
+            right = mid
+    return threshold_grid[left]
+
+
+def find_thresholds(
+        model_name: str,
+        dataset_name: str,
+        dataset_config: str,
+        save_path: str,
+        batch_size: int = 8,
+        max_samples: int = 128,
+        max_length: int = 256,
+        cett_target: float = 0.2,
+        n_thresholds: int = 10000,
+        num_workers: int = 8,
+        seed: int = 42,
+        device: torch.device = torch.device("cpu"),
+    ):
+
+    # Load tokenizer and model
+    logger.info(f"Loading model: {model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float32,
+        device_map="auto" if device.type == "cuda" else None,
+    )
+
+    if device.type != "cuda":
+        model = model.to(device)
+
+    model.eval()
+    model.activation_capture = ActivationCapture(model)
+    model.activation_capture.register_hooks(hooks=[Hook.UP])
+
+    # Load dataset
+    logger.info(f"Loading dataset: {dataset_name}")
+    if dataset_config:
+        dataset = load_dataset(
+            dataset_name, dataset_config, split="train", streaming=True
+        )
+    else:
+        dataset = load_dataset(dataset_name, split="train", streaming=True)
+    dataset = dataset.shuffle(buffer_size=10000, seed=seed)
+
+    def sample_and_tokenize(examples):
+        """Tokenize a single example, truncating and padding to max_length."""
+        texts = examples["text"]
+        tokenized = tokenizer(
+            texts,
+            max_length=max_length,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt"
+        )
+
+        # Return tensors of shape [1, max_length]
+        return {
+            "input_ids": tokenized["input_ids"],
+            "attention_mask": tokenized["attention_mask"]
+        }
+
+    # Tokenize
+    dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False)
+    dataset = dataset.with_format("torch")
+
+    dataloader = TorchDataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, prefetch_factor=2)  # type: ignore
+
+    # Compute thresholds for each layer across all dataset entries
+    logger.info(f"Beginning to compute thresholds using {max_samples} samples")
+    thresholds = defaultdict(list)
+    with torch.no_grad():
+        all_col_norms = {layer_idx: layer.mlp.down_proj.weight.norm(dim=0) \
+                         for layer_idx, layer in enumerate(model.activation_capture.get_layers())}
+        for batch in tqdm(dataloader, total=max_samples//batch_size):
+            input_ids = batch["input_ids"].to(device)
+            attention_mask = batch["attention_mask"].to(device)
+
+            _ = model(input_ids=input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))
+
+            for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
+                down_weight = layer.mlp.down_proj.weight
+                col_norms = all_col_norms[layer_idx]
+                activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
+                threshold = calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds)
+                thresholds[layer_idx].append(threshold)
+
+            model.activation_capture.clear_captures()
+            if device.type == "cuda":
+                torch.cuda.empty_cache()
+
+    for layer_idx, layer_thresholds in thresholds.items():
+        thresholds[layer_idx] = (sum(layer_thresholds) / len(layer_thresholds)).item()
+
+    # Save layerwise thresholds as record in central json file
+    if os.path.exists(save_path):
+        with open(save_path, mode="r", encoding="utf-8") as read_file:
+            threshold_dict = json.load(read_file)
+    else:
+        threshold_dict = {}
+    threshold_dict[model_name] = thresholds
+    with open(save_path, mode="w", encoding="utf-8") as write_file:
+        json.dump(threshold_dict, write_file)
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Find per-layer CETT-based activation thresholds"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        required=True,
+        help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="allenai/c4",
+        help="Dataset name (default: allenai/c4)",
+    )
+    parser.add_argument(
+        "--dataset_config",
+        type=str,
+        default="en",
+        help="Dataset configuration (e.g., en for C4)",
+    )
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="thresholds.json",
+        help="Path to json file for thresholds",
+    )
+    parser.add_argument(
+        "--max_samples",
+        type=int,
+        default=500,
+        help="Maximum number of samples to process",
+    )
+    parser.add_argument(
+        "--cett_target",
+        type=float,
+        default=0.2,
+        help="Optimal CETT value for threshold-finding",
+    )
+    parser.add_argument(
+        "--n_thresholds",
+        type=int,
+        default=500,
+        help="Number of candidate thresholds to search over for threshold-finding",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed")
+    parser.add_argument(
+        "--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)"
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # Set seed
+    set_seed(args.seed)
+
+    # Setup device
+    if args.device == "auto":
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        device = torch.device(args.device)
+
+    find_thresholds(
+        model_name=args.model_name,
+        dataset_name=args.dataset,
+        dataset_config=args.dataset_config,
+        max_samples=args.max_samples,
+        cett_target=args.cett_target,
+        n_thresholds=args.n_thresholds,
+        save_path=args.save_path,
+        seed=args.seed,
+        device=device,
+    )
+
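
# Worked toy example (illustrative only, not part of the patch) of the helpers in
# src/cett.py, assuming the repository's dependencies are installed. For a token
# with down-projection input a and weight W_down, cett_from_threshold returns
# ||W_down @ (a * m)|| / ||W_down @ a||, where m_j = 1 if |a_j| * ||W_down[:, j]|| < t,
# i.e. the share of the MLP output norm carried by channels deemed prunable at
# threshold t; calculate_threshold then searches for the largest t that keeps the
# mean CETT (over tokens able to exceed the target) at or below cett_target.
import torch

from src.cett import calculate_threshold, cett_from_threshold

torch.manual_seed(0)
tokens, d_ff, d_model = 8, 64, 32
activations = torch.randn(tokens, d_ff)      # inputs to down_proj (what Hook.UP captures)
down_weight = torch.randn(d_model, d_ff)     # down_proj.weight, shape [d_model, d_ff]
col_norms = down_weight.norm(dim=0)

per_token_cett = cett_from_threshold(activations, down_weight, threshold=0.01)
layer_threshold = calculate_threshold(activations, down_weight, col_norms,
                                      cett_target=0.2, n_thresholds=256)
print(per_token_cett.shape, float(layer_threshold))   # torch.Size([8]) and the chosen threshold
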
diff --git a/src/modeling_skip.py b/src/modeling_skip.py
index 18482c3..91b53e1 100644
--- a/src/modeling_skip.py
+++ b/src/modeling_skip.py
@@ -20,7 +20,7 @@
 from transformers.utils.import_utils import is_torch_flex_attn_available
 
 from sparse_transformers import WeightCache, sparse_mlp_forward
-from src.activation_capture import ActivationCaptureDefault
+from src.activation_capture import ActivationCapture
 
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
@@ -352,7 +352,7 @@ def forward(
 
 
 def build_skip_connection_model_for_causal_lm(pretrained_model_class: type[PreTrainedModel], base_model_class: type[PreTrainedModel]):
-    ACTIVATION_CAPTURE = ActivationCaptureDefault
+    ACTIVATION_CAPTURE = ActivationCapture
 
     class SkipConnectionModelForCausalLM(pretrained_model_class, GenerationMixin):
         _tied_weights_keys = ["lm_head.weight"]
diff --git a/src/models/gemma3n/activation_capture_gemma.py b/src/models/gemma3n/activation_capture_gemma.py
deleted file mode 100644
index 8364d6f..0000000
--- a/src/models/gemma3n/activation_capture_gemma.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from src.activation_capture import ActivationCaptureDefault
-
-
-class ActivationCaptureGemma3n(ActivationCaptureDefault):
-    """Helper class to capture activations from model layers."""
-
-    def _register_gate_hook(self, layer_idx, layer):
-        handle = layer.mlp.act_fn.register_forward_hook(
-            self._create_mlp_hook(layer_idx, 'gate')
-        )
-        return handle
diff --git a/src/models/gemma3n/modelling_gemma_skip.py b/src/models/gemma3n/modelling_gemma_skip.py
index 7fcdd8a..cf0480f 100644
--- a/src/models/gemma3n/modelling_gemma_skip.py
+++ b/src/models/gemma3n/modelling_gemma_skip.py
@@ -27,7 +27,6 @@
 from sparse_transformers import sparse_mlp_forward
 
 from src.models.gemma3n.configuration_gemma_skip import Gemma3nSkipConnectionConfig
-from src.models.gemma3n.activation_capture_gemma import ActivationCaptureGemma3n
 from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm
 
 logger = logging.get_logger(__name__)
@@ -413,7 +412,6 @@ def project_per_layer_inputs(
 
 Gemma3nSkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Gemma3nSkipPreTrainedModel, Gemma3nSkipConnectionModel)
 class Gemma3nSkipConnectionForCausalLM(Gemma3nSkipConnectionForCausalLMBase):
-    ACTIVATION_CAPTURE = ActivationCaptureGemma3n
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
diff --git a/src/models/phi3/activation_capture_phi.py b/src/models/phi3/activation_capture_phi.py
deleted file mode 100644
index e2e6cf9..0000000
--- a/src/models/phi3/activation_capture_phi.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from src.activation_capture import ActivationCapture
-import torch.nn.functional as F
-
-
-
-class ActivationCapturePhi3(ActivationCapture):
-    """Helper class to capture activations from model layers."""
-    has_gate_proj: bool = True
-    has_up_proj: bool = True
-
-    def get_layers(self, model):
-        return model.model.layers
-
-    def _register_gate_hook(self, layer_idx, layer):
-        def hook(module, input, output):
-            key1 = f"{layer_idx}_{'gate'}"
-            key2 = f"{layer_idx}_{'up'}"
-            # Just detach, don't clone or move to CPU yet
-            gate_outputs, up_outputs = output.chunk(2, dim=1)
-            self.mlp_activations[key1] = gate_outputs.detach()
-            self.mlp_activations[key2] = up_outputs.detach()
-            return output
-        handle = layer.mlp.gate_up_proj.register_forward_hook(hook)
-        return handle
-
-    def _register_up_hook(self, layer_idx, layer):
-        def hook(module, input, output):
-            key = f"{layer_idx}_{'up'}"
-            # Just detach, don't clone or move to CPU yet
-            up_outputs = output.chunk(2, dim=1)[1]
-            self.mlp_activations[key] = up_outputs.detach()
-            return output
-        handle = layer.mlp.gate_up_proj.register_forward_hook(hook)
-        return handle
-
-    def get_gate_activations(self, layer_idx):
-        """Get combined MLP activations for a layer."""
-        gate_key = f"{layer_idx}_gate"
-        if gate_key in self.mlp_activations:
-            gate_act = self.mlp_activations[gate_key]
-            return F.silu(gate_act)
-        return None
-
-    def get_mlp_activations(self, layer_idx):
-        """Get combined MLP activations for a layer."""
-        gate_key = f"{layer_idx}_gate"
-        up_key = f"{layer_idx}_up"
-
-        if gate_key in self.mlp_activations and up_key in self.mlp_activations:
-            # Compute gated activations: gate(x) * up(x)
-            gate_act = self.mlp_activations[gate_key]
-            up_act = self.mlp_activations[up_key]
-
-            # Apply SwiGLU activation: silu(gate) * up
-            gated_act = F.silu(gate_act) * up_act
-            return gated_act
-
-        return None
\ No newline at end of file
diff --git a/src/models/phi3/modelling_phi_skip.py b/src/models/phi3/modelling_phi_skip.py
index d4303c4..f3851b7 100644
--- a/src/models/phi3/modelling_phi_skip.py
+++ b/src/models/phi3/modelling_phi_skip.py
@@ -25,7 +25,6 @@
 
 from src.models.phi3.configuration_phi_skip import Phi3SkipConnectionConfig
 from src.modeling_skip import SkipMLP, SkipDecoderLayer, FastLoRAProjection, build_skip_connection_model, build_skip_connection_model_for_causal_lm
-from .activation_capture_phi import ActivationCapturePhi3
 
 logger = logging.get_logger(__name__)
 
@@ -339,8 +338,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 Phi3SkipConnectionForCausalLMBase = build_skip_connection_model_for_causal_lm(Phi3SkipPreTrainedModel, Phi3SkipConnectionModel)
 class Phi3SkipConnectionForCausalLM(Phi3SkipConnectionForCausalLMBase):
-    ACTIVATION_CAPTURE = ActivationCapturePhi3
-
    _keys_to_ignore_on_load_missing = [
        "model.layers.*.mlp.combined_proj_buffer",
        "model.layers.*.mlp.down_proj_buffer",