From b4bc7d13f0763f09c89ced5dca891acb02305957 Mon Sep 17 00:00:00 2001
From: Kira Selby
Date: Mon, 28 Jul 2025 08:57:08 -0400
Subject: [PATCH 1/2] Add flag to disable weight cache and compute sparsity without union over batch dimension

Signed-off-by: Kira Selby
---
 downstream_eval.py                       |  4 +++
 src/modeling_skip.py                     | 40 +++++++++++++++++-------
 src/models/llama/modelling_llama_skip.py |  3 +-
 3 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/downstream_eval.py b/downstream_eval.py
index 6f60872..240ff5b 100644
--- a/downstream_eval.py
+++ b/downstream_eval.py
@@ -32,6 +32,8 @@ def parse_args():
                         help="Size of lora predictors to use as percentage of total hidden size")
     parser.add_argument("--sp_layers", default="all", nargs='+',
                         help="Which layers to use sparse predictors for")
+    parser.add_argument("--disable_weight_cache", action="store_true",
+                        help="Disable weight cache and compute sparse mlp manually")
 
     return parser.parse_args()
 
@@ -57,6 +59,8 @@ def main():
         args.sp_layers = [int(x) for x in args.sp_layers]
     config.lora_size = args.lora_size / 100.0
     config.sp_layers = args.sp_layers
+    if args.disable_weight_cache:
+        config.use_weight_cache = False
     model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)
     for layer_idx in model.get_decoder().sp_layers:
         layer = model.get_decoder().layers[layer_idx]
diff --git a/src/modeling_skip.py b/src/modeling_skip.py
index 4c15d14..b8f1259 100644
--- a/src/modeling_skip.py
+++ b/src/modeling_skip.py
@@ -18,6 +18,7 @@
 from transformers.processing_utils import Unpack
 from transformers.utils import logging
 from transformers.utils.import_utils import is_torch_flex_attn_available
+from transformers.activations import ACT2FN
 
 from sparse_transformers import WeightCache, sparse_mlp_forward
 from src.activation_capture import ActivationCapture
@@ -52,7 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.up(self.down(x))
 
 class SkipMLP(nn.Module):
-    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False, act_fn="silu"):
+    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False, act_fn="silu", use_weight_cache=True):
         super().__init__()
         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
@@ -60,7 +61,9 @@ def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bi
         self.sparsity = sparsity
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.act_fn = act_fn
+        self.act_fn_name = act_fn
+        self.act_fn = ACT2FN[act_fn]
+        self.use_weight_cache = use_weight_cache
 
         # Initialize mask but defer WeightCache creation until post_init
         self.init_mask = torch.ones(intermediate_size, dtype=torch.bool)
@@ -68,13 +71,16 @@ def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bi
 
         self.weight_cache : Optional[WeightCache] = None
 
+        if not self.use_weight_cache:
+            self.weight_mask = None
+
         # Register buffers - start with reasonable size and ensure they can be resized
         self.register_buffer('down_proj_buffer', torch.zeros(1, hidden_size, requires_grad=False))
         self.register_buffer('combined_proj_buffer', torch.zeros(1, 2 * int(intermediate_size * sparsity), requires_grad=False))
 
     def initialize_weight_cache(self):
         """Tie weights after weights are loaded (called from post_init)."""
-        if self.weight_cache is None:
+        if self.weight_cache is None and self.use_weight_cache:
             # Create and initialize weight cache
             self.weight_cache = WeightCache(
                 self.init_mask,
@@ -96,14 +102,20 @@ def to(self, *args, **kwargs):
         return result
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        out = sparse_mlp_forward(
-            x.detach(),
-            self.weight_cache.get_concat_weight(),  # type: ignore
-            self.weight_cache.get_active_down_weight(),  # type: ignore
-            self.down_proj_buffer,
-            self.combined_proj_buffer,
-            self.act_fn
-        )
+        if self.use_weight_cache:
+            out = sparse_mlp_forward(
+                x.detach(),
+                self.weight_cache.get_concat_weight(),  # type: ignore
+                self.weight_cache.get_active_down_weight(),  # type: ignore
+                self.down_proj_buffer,
+                self.combined_proj_buffer,
+                self.act_fn_name
+            )
+        else:
+            # This should be replaced by a proper sparse implementation if we want to use this for anything other than simple debugging
+            up = self.act_fn(self.gate_proj(x)) * self.up_proj(x) * self.weight_mask
+            out = self.down_proj(up)
+            self.weight_mask = None
         return out
 
 
@@ -115,6 +127,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
         self.sparsity = config.sparsity
+        self.use_weight_cache = config.use_weight_cache
 
         self._init_components(config, layer_idx)
 
@@ -163,7 +176,10 @@ def weight_cache(self):
     def _compute_binary_mask(self, hidden_states):
         lora_proj_scores = self.mlp_lora_proj(hidden_states.view(-1, hidden_states.shape[-1]))
         binary_mask = (lora_proj_scores >= 0).bool()
-        self.weight_cache.update_active_weights(binary_mask.any(dim=0))  # type: ignore
+        if self.use_weight_cache:
+            self.weight_cache.update_active_weights(binary_mask.any(dim=0))  # type: ignore
+        else:
+            self.mlp.weight_mask = binary_mask
 
     def forward(
         self,
diff --git a/src/models/llama/modelling_llama_skip.py b/src/models/llama/modelling_llama_skip.py
index 7462da7..f270002 100644
--- a/src/models/llama/modelling_llama_skip.py
+++ b/src/models/llama/modelling_llama_skip.py
@@ -61,7 +61,8 @@ def _set_mlp_inference(self, config, layer_idx):
             config.intermediate_size,
             config.sparsity,
             config.mlp_bias,
-            config.hidden_act
+            config.hidden_act,
+            config.use_weight_cache
         )

From a8c5fc4e2410060640e77a7b24e1f236cb24eae8 Mon Sep 17 00:00:00 2001
From: Kira Selby
Date: Mon, 28 Jul 2025 09:37:07 -0400
Subject: [PATCH 2/2] Set default value of use_weight_cache to true if not found in config

Signed-off-by: Kira Selby
---
 src/modeling_skip.py                     | 2 +-
 src/models/llama/modelling_llama_skip.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/modeling_skip.py b/src/modeling_skip.py
index b8f1259..28562d9 100644
--- a/src/modeling_skip.py
+++ b/src/modeling_skip.py
@@ -127,7 +127,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
         self.sparsity = config.sparsity
-        self.use_weight_cache = config.use_weight_cache
+        self.use_weight_cache = getattr(config, 'use_weight_cache', True)
 
         self._init_components(config, layer_idx)
 
diff --git a/src/models/llama/modelling_llama_skip.py b/src/models/llama/modelling_llama_skip.py
index f270002..418fcaa 100644
--- a/src/models/llama/modelling_llama_skip.py
+++ b/src/models/llama/modelling_llama_skip.py
@@ -62,7 +62,7 @@ def _set_mlp_inference(self, config, layer_idx):
             config.sparsity,
             config.mlp_bias,
             config.hidden_act,
-            config.use_weight_cache
+            getattr(config, 'use_weight_cache', True)
         )
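
For quick reference, a minimal, illustrative sketch of enabling the new option when loading a model directly rather than through downstream_eval.py. It assumes a converted sparse checkpoint whose custom model classes are already registered (as they are when downstream_eval.py runs); the checkpoint path below is a placeholder.

    from transformers import AutoConfig, AutoModelForCausalLM

    # Load the checkpoint's config and opt out of the weight cache; this is
    # what downstream_eval.py now does when --disable_weight_cache is passed.
    config = AutoConfig.from_pretrained("path/to/sparse-llama-checkpoint")  # placeholder path
    config.use_weight_cache = False

    # The layer reads the flag with getattr(config, 'use_weight_cache', True),
    # so older configs that lack the field keep using the default cached path.
    model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)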