4 changes: 4 additions & 0 deletions downstream_eval.py
@@ -32,6 +32,8 @@ def parse_args():
                         help="Size of lora predictors to use as percentage of total hidden size")
     parser.add_argument("--sp_layers", default="all", nargs='+',
                         help="Which layers to use sparse predictors for")
+    parser.add_argument("--disable_weight_cache", action="store_true",
+                        help="Disable weight cache and compute sparse mlp manually")
     return parser.parse_args()
 
 
@@ -57,6 +59,8 @@ def main():
     args.sp_layers = [int(x) for x in args.sp_layers]
     config.lora_size = args.lora_size / 100.0
     config.sp_layers = args.sp_layers
+    if args.disable_weight_cache:
+        config.use_weight_cache = False
     model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)
     for layer_idx in model.get_decoder().sp_layers:
         layer = model.get_decoder().layers[layer_idx]
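Taken together, the two hunks above wire the new CLI switch through to the model config: passing --disable_weight_cache to downstream_eval.py sets config.use_weight_cache = False before from_pretrained builds the model. A minimal sketch of the same toggle done directly in Python (the checkpoint path and the AutoConfig call are assumptions; use_weight_cache, sp_layers, and the per-layer mlp attribute come from this PR):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/skip-llama-checkpoint")  # hypothetical checkpoint path
config.use_weight_cache = False  # equivalent to passing --disable_weight_cache

model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)
for layer_idx in model.get_decoder().sp_layers:
    layer = model.get_decoder().layers[layer_idx]
    assert layer.use_weight_cache is False       # decoder layer reads the flag via getattr
    assert layer.mlp.use_weight_cache is False   # and forwards it into SkipMLP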
40 changes: 28 additions & 12 deletions src/modeling_skip.py
@@ -18,6 +18,7 @@
 from transformers.processing_utils import Unpack
 from transformers.utils import logging
 from transformers.utils.import_utils import is_torch_flex_attn_available
+from transformers.activations import ACT2FN
 
 from sparse_transformers import WeightCache, sparse_mlp_forward
 from src.activation_capture import ActivationCapture
@@ -52,29 +53,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.up(self.down(x))
 
 class SkipMLP(nn.Module):
-    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False, act_fn="silu"):
+    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False, act_fn="silu", use_weight_cache=True):
         super().__init__()
         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
         self.sparsity = sparsity
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.act_fn = act_fn
+        self.act_fn_name = act_fn
+        self.act_fn = ACT2FN[act_fn]
+        self.use_weight_cache = use_weight_cache
 
         # Initialize mask but defer WeightCache creation until post_init
         self.init_mask = torch.ones(intermediate_size, dtype=torch.bool)
         #self.init_mask[int(intermediate_size * (1-sparsity)):] = 0
 
         self.weight_cache : Optional[WeightCache] = None
 
+        if not self.use_weight_cache:
+            self.weight_mask = None
+
         # Register buffers - start with reasonable size and ensure they can be resized
         self.register_buffer('down_proj_buffer', torch.zeros(1, hidden_size, requires_grad=False))
         self.register_buffer('combined_proj_buffer', torch.zeros(1, 2 * int(intermediate_size * sparsity), requires_grad=False))
 
     def initialize_weight_cache(self):
         """Tie weights after weights are loaded (called from post_init)."""
-        if self.weight_cache is None:
+        if self.weight_cache is None and self.use_weight_cache:
             # Create and initialize weight cache
             self.weight_cache = WeightCache(
                 self.init_mask,
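SkipMLP now keeps both the activation name (act_fn_name, still handed to the fused sparse kernel) and the resolved callable from ACT2FN (needed by the dense fallback in forward). A small sketch of what the ACT2FN lookup yields, using the default "silu" (the tensor shape is arbitrary):

import torch
import torch.nn.functional as F
from transformers.activations import ACT2FN

act_fn = ACT2FN["silu"]        # maps the config string to a callable nn.Module
x = torch.randn(2, 8)
assert torch.allclose(act_fn(x), F.silu(x))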
@@ -96,14 +102,20 @@ def to(self, *args, **kwargs):
         return result
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        out = sparse_mlp_forward(
-            x.detach(),
-            self.weight_cache.get_concat_weight(),  # type: ignore
-            self.weight_cache.get_active_down_weight(),  # type: ignore
-            self.down_proj_buffer,
-            self.combined_proj_buffer,
-            self.act_fn
-        )
+        if self.use_weight_cache:
+            out = sparse_mlp_forward(
+                x.detach(),
+                self.weight_cache.get_concat_weight(),  # type: ignore
+                self.weight_cache.get_active_down_weight(),  # type: ignore
+                self.down_proj_buffer,
+                self.combined_proj_buffer,
+                self.act_fn_name
+            )
+        else:
+            # Dense debugging fallback: apply the predicted mask to a standard gated-MLP pass; replace with a proper sparse implementation for anything beyond debugging
+            up = self.act_fn(self.gate_proj(x)) * self.up_proj(x) * self.weight_mask
+            out = self.down_proj(up)
+            self.weight_mask = None
         return out
 
 
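For intuition on the new else branch: it is the ordinary gated MLP with the predicted neuron mask multiplied into the intermediate activation, so with an all-ones mask it reproduces the dense output exactly. A self-contained check under that assumption (sizes are arbitrary, silu stands in for the configured activation):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter, tokens = 16, 64, 4
gate = nn.Linear(hidden, inter, bias=False)
up = nn.Linear(hidden, inter, bias=False)
down = nn.Linear(inter, hidden, bias=False)

x = torch.randn(tokens, hidden)
mask = torch.ones(tokens, inter)                      # every neuron active -> dense behaviour
masked_out = down(F.silu(gate(x)) * up(x) * mask)
dense_out = down(F.silu(gate(x)) * up(x))
assert torch.allclose(masked_out, dense_out)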
@@ -115,6 +127,7 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
         self.sparsity = config.sparsity
+        self.use_weight_cache = getattr(config, 'use_weight_cache', True)
 
         self._init_components(config, layer_idx)
 
@@ -163,7 +176,10 @@ def weight_cache(self):
     def _compute_binary_mask(self, hidden_states):
         lora_proj_scores = self.mlp_lora_proj(hidden_states.view(-1, hidden_states.shape[-1]))
         binary_mask = (lora_proj_scores >= 0).bool()
-        self.weight_cache.update_active_weights(binary_mask.any(dim=0))  # type: ignore
+        if self.use_weight_cache:
+            self.weight_cache.update_active_weights(binary_mask.any(dim=0))  # type: ignore
+        else:
+            self.mlp.weight_mask = binary_mask
 
     def forward(
         self,
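Note the difference in mask granularity between the two branches: the cached path reduces the per-token mask with any(dim=0), so the WeightCache keeps one neuron set for the whole flattened batch, while the fallback hands SkipMLP the full per-token mask. A shape-only sketch (sizes are made up):

import torch

tokens, intermediate = 5, 64
binary_mask = torch.rand(tokens, intermediate) > 0.5   # stands in for (lora_proj_scores >= 0)

per_neuron = binary_mask.any(dim=0)   # shape (intermediate,): neurons active for any token, fed to the WeightCache
per_token = binary_mask               # shape (tokens, intermediate): stored on mlp.weight_mask for the dense fallback

assert per_neuron.shape == (intermediate,)
assert per_token.shape == (tokens, intermediate)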
3 changes: 2 additions & 1 deletion src/models/llama/modelling_llama_skip.py
@@ -61,7 +61,8 @@ def _set_mlp_inference(self, config, layer_idx):
             config.intermediate_size,
             config.sparsity,
             config.mlp_bias,
-            config.hidden_act
+            config.hidden_act,
+            getattr(config, 'use_weight_cache', True)
         )
 
 
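Because the flag is read with getattr and a default of True, configs serialized before this change keep the weight cache enabled and nothing else in the call changes. A minimal illustration (LlamaConfig stands in here for the actual skip config class):

from transformers import LlamaConfig

old_config = LlamaConfig()   # predates this change: no use_weight_cache attribute
assert getattr(old_config, 'use_weight_cache', True) is True

old_config.use_weight_cache = False   # set by downstream_eval.py when --disable_weight_cache is passed
assert getattr(old_config, 'use_weight_cache', True) is False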