From 694c3172c02acdc992b043dfa14a3d5832879d49 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:05:23 +0100 Subject: [PATCH 01/49] Add decilm modelling code Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 11 + .../converters/convert_llama3_to_decilm.py | 3 +- .../decilm/deci_lm_hf_code/block_config.py | 308 ++ .../deci_lm_hf_code/configuration_decilm.py | 210 ++ .../transformers_4_51_3__cache_utils.py | 2537 +++++++++++++++++ 5 files changed, 3068 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 598957f86..4cf8ddd34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,18 @@ repos: hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] + # See: commit hooks modifies block_config.py leading to test_compress.py failing (#25) · Issues · omniml / modelopt · GitLab + exclude: > + (?x)^( + modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py + )$ - id: ruff-format + exclude: > + (?x)^( + modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py + )$ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index d17e7ef74..96b96f351 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -24,9 +24,10 @@ from puzzle_tools.checkpoint_utils import copy_tokenizer from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code from puzzle_tools.conversion_utils import convert_model_weights_to_decilm -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig from transformers import LlamaConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + """ example: diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py new file mode 100644 index 000000000..d5eebfa35 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config.py @@ -0,0 +1,308 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# mypy: ignore-errors +import dataclasses +import inspect +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Type, Union, get_args, get_origin + + +@dataclass(frozen=True, kw_only=True) +class BaseDataclass: + """ + A dataclass base class with several utilities: + 1. Comparison via string representation. + 2. Initialization of dataclasses fields from dicts. + 3. Setting attributes even though it's frozen (but only inside __post_init__!) + """ + + def __eq__(self, other: "BaseDataclass") -> bool: + return str(self) == str(other) + + def __hash__(self) -> int: + return hash(str(self)) + + def __lt__(self, other: "BaseDataclass") -> bool: + return str(self) < str(other) + + def _force_setattr(self, name: str, value: Any) -> None: + """ + Set an attribute even in frozen dataclasses. + Use only inside __post_init__! + """ + assert _is_called_from_post_init(), ( + "_force_setattr should only be called from __post_init__, " + "if you need to change an attribute use dataclasses.replace " + "or create a new instance :)" + ) + object.__setattr__(self, name, value) + + def __post_init__(self): + """ + Init dataclass fields from dicts + """ + for field in dataclasses.fields(self): + field_dict = getattr(self, field.name) + if isinstance(field_dict, dict) and _is_dataclass_type(field.type): + dataclass_cls = _get_dataclass_type(field.type) + sub_fields = [field.name for field in dataclasses.fields(dataclass_cls)] + unsupported_fields = [ + field_name for field_name in field_dict.keys() if field_name not in sub_fields + ] + if len(unsupported_fields) > 0: + warnings.warn( + f"Removed unsupported fields {unsupported_fields} from {dataclass_cls}" + ) + + field_dict = {k: v for k, v in field_dict.items() if k not in unsupported_fields} + self._force_setattr(field.name, dataclass_cls(**field_dict)) + + +def _is_called_from_post_init() -> bool: + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "__post_init__": + return True + frame = frame.f_back + return False + + +def _is_dataclass_type(tp: Type) -> bool: + """ + Like dataclasses.is_dataclass but also works for Optional[] and Union[] of a dataclass type + """ + try: + _get_dataclass_type(tp) + return True + except: + return False + + +def _get_dataclass_type(tp: Type) -> dataclass: + """ + If the given type is a dataclass, the function returns it. + If it is a Union[] or Optional[], the function extracts the first dataclass type. + If no dataclass type is found, the function raises a ValueError. + """ + origin = get_origin(tp) + if origin is Union: + for type_in_union in get_args(tp): + if dataclasses.is_dataclass(type_in_union): + return type_in_union + if dataclasses.is_dataclass(tp): + return tp + raise ValueError("Not a dataclass") + + +@dataclass(frozen=True, kw_only=True) +class SubblockConfig(BaseDataclass): + no_op: bool = False + replace_with_linear: bool = False + sparsify: Optional[list[str]] = None + weights_precision: Optional[str] = "bf16" + + def __post_init__(self): + super().__post_init__() + assert not (self.no_op and self.replace_with_linear) + if self.no_op: + self._force_setattr("sparsify", None) + + @abstractmethod + def to_blockconfig(self) -> "BlockConfig": + """ " + Convert to a block including this subblock only. + """ + ... + + +@dataclass(frozen=True, kw_only=True) +class MoEConfig(BaseDataclass): + """ + Configuration class for Mixture of Experts parameters. 
+ """ + + num_local_experts: int = 8 + num_experts_per_tok: int = 1 + expert_intermediate_dim: int = 8192 + shared_expert_intermediate_dim: int = 8192 + # router_aux_loss_coef: float = 0.01 + # router_z_loss_coef: float = 0.0 # Optional z-loss coefficient + + def __post_init__(self): + # Validate the configuration + if self.num_local_experts <= 0: + raise ValueError(f"num_local_experts must be positive, got {self.num_local_experts}") + if self.num_experts_per_tok <= 0: + raise ValueError(f"top_k must be positive, got {self.top_k}") + if self.num_experts_per_tok > self.num_local_experts: + raise ValueError( + f"top_k ({self.top_k}) cannot be greater than num_local_experts ({self.num_local_experts})" + ) + # if self.router_aux_loss_coef < 0: + # raise ValueError(f"router_aux_loss_coef must be non-negative, got {self.router_aux_loss_coef}") + + +@dataclass(frozen=True, kw_only=True) +class MambaConfig(BaseDataclass): + state_dim: int + num_heads: int + head_dim: int + num_groups: int + + +@dataclass(frozen=True, kw_only=True) +class Llama4AttentionConfig(BaseDataclass): + attention_chunk_size: Optional[int] = None + use_rope: Optional[bool] = None + use_qk_norm: Optional[bool] = None + attn_scale: Optional[float] = None + floor_scale: Optional[float] = None + attn_temperature_tuning: Optional[bool] = None + attention_dropout: Optional[float] = None + + +@dataclass(frozen=True, kw_only=True) +class AttentionConfig(SubblockConfig): + n_heads_in_group: Optional[int] = None + window_length: Optional[int] = None + num_sink_tokens: Optional[int] = None + use_prefill_window_in_sink_attention: bool = False + unshifted_sink: bool = False + mamba: Optional[MambaConfig] = None + llama4: Optional[Llama4AttentionConfig] = None + + def __post_init__(self): + super().__post_init__() + + if self.no_op: + assert not self.replace_with_linear + assert not self.is_mamba + assert not self.is_llama4 + + if self.no_op or self.replace_with_linear or self.is_mamba: + for irrelevant_att in [ + "n_heads_in_group", + "window_length", + "num_sink_tokens", + "use_prefill_window_in_sink_attention", + "unshifted_sink", + "attention_chunk_size", + "attn_scale", + "floor_scale", + "attn_temperature_tuning", + "attention_dropout", + "use_qk_norm", + ]: + self._force_setattr(irrelevant_att, None) + else: + assert self.n_heads_in_group is not None + + if self.is_sink: + assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), ( + "Unshifted sink uses its own kind of explicit masking, not standard window. " + "Set use_prefill_window_in_sink_attention to False." 
+ ) + assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), ( + "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True" + ) + + if self.is_llama4: + assert not self.is_sink, "Sink not support with Llama4 currently" + assert not self.is_sliding, "Sliding window not support with Llama4 currently" + assert not self.unshifted_sink, "Unshifted sink not support with Llama4 currently" + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=self, ffn=FFNConfig(no_op=True)) + + @property + def prefill_sliding_window(self) -> Optional[int]: + if self.window_length is not None: + if not self.is_sink or self.use_prefill_window_in_sink_attention: + return self.window_length + return None + + @property + def is_sliding(self) -> bool: + return self.prefill_sliding_window is not None + + @property + def is_sink(self) -> bool: + return (self.window_length is not None) and (self.num_sink_tokens is not None) + + @property + def is_mamba(self) -> bool: + return self.mamba is not None + + @property + def is_llama4(self) -> bool: + return self.llama4 is not None + + +@dataclass(frozen=True, kw_only=True) +class FFNConfig(SubblockConfig): + gated: Optional[bool] = ( + True # Gated Linear Unit e.g. SwiGLU or vanilla MLP (up -> activation -> down) + ) + hidden_act: Optional[str] = "silu" + moe: Optional[MoEConfig] = None + intermediate_size: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.no_op or self.replace_with_linear: + self._force_setattr("gated", None) + self._force_setattr("hidden_act", None) + self._force_setattr("moe", None) + self._force_setattr("intermediate_size", None) + elif self.is_moe: + self._force_setattr("gated", None) + self._force_setattr("hidden_act", None) + self._force_setattr("intermediate_size", None) + else: + assert self.intermediate_size is not None, ( + "Intermediate size must be provided for an FFN block" + ) + assert self.intermediate_size % 256 == 0, "Intermediate size must be divisible by 256" + + def to_blockconfig(self) -> "BlockConfig": + return BlockConfig(attention=AttentionConfig(no_op=True), ffn=self) + + @property + def is_moe(self) -> bool: + return self.moe is not None + + +SUBBLOCK_CLS_DICT = { + "attention": AttentionConfig, + "ffn": FFNConfig, +} + + +@dataclass(frozen=True, kw_only=True) +class BlockConfig(BaseDataclass): + attention: Optional[AttentionConfig] = None + ffn: Optional[FFNConfig] = None + parallel_blocks: Optional[list["BlockConfig"]] = None + + def __post_init__(self): + super().__post_init__() + if (self.parallel_blocks is not None) and isinstance(self.parallel_blocks[0], dict): + initialized_block_configs = [ + BlockConfig(**block_config) for block_config in self.parallel_blocks + ] + self._force_setattr("parallel_blocks", initialized_block_configs) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py new file mode 100644 index 000000000..c37b9adaf --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/configuration_decilm.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors + +import copy +import dataclasses +import warnings +from typing import Any + +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + +from .block_config import BlockConfig +from .transformers_4_44_2__configuration_llama import LlamaConfig + +# fakes imports to make AutoConfig infer dependencies +from .transformers_4_44_2__modeling_rope_utils import rope_config_validation +from .transformers_4_51_3__cache_utils import HybridChunkedCache +from .transformers_4_51_3__configuration_llama4 import Llama4Config + +# make sure that auto-formatting doesn't remove the fake imports +rope_config_validation +Llama4Config +HybridChunkedCache + + +class DeciLMConfig(LlamaConfig): + model_type = "nemotron-nas" + + # Mapping from global attribute names to their per-layer equivalents in block_configs + # Format: 'global_name': ('block_section', 'layer_name') + PER_LAYER_ATTRIBUTE_MAPPING = { + "intermediate_size": ("ffn", "intermediate_size"), + "num_key_value_heads": ( + "attention", + "n_heads_in_group", + ), # Note: derived value (num_heads / num_kv_heads) + "hidden_act": ("ffn", "hidden_act"), + "sliding_window": ("attention", "window_length"), # Note: different name! + } + + def __init__( + self, + block_configs: list[dict] | list[BlockConfig] | None = None, + position_embedding_type: str = "rope", + llama4_attn_implementation: str | None = None, + block_return_only_hidden_states: bool = False, + router_aux_loss_coef: float = 0.01, + router_z_loss_coef: float = 0.0, + output_router_logits: bool = False, + head_dim: int | None = 128, + o_proj_bias: bool = False, + **kwargs, + ): + self.block_configs: list[BlockConfig] = block_configs + if self.block_configs is not None: + if isinstance(self.block_configs[0], dict): + self.block_configs = [BlockConfig(**conf) for conf in self.block_configs] + + assert position_embedding_type in ["rope", "rope_llama4", "none", "mistral_yarn"] + self.position_embedding_type = position_embedding_type + if self.position_embedding_type == "none": + self.rope_theta = None + self.rope_scaling = None + + self.block_return_only_hidden_states = block_return_only_hidden_states + self.router_aux_loss_coef = router_aux_loss_coef + self.router_z_loss_coef = router_z_loss_coef + self.output_router_logits = output_router_logits + self.o_proj_bias = o_proj_bias + + self._choose_llama4_attn_implementation(llama4_attn_implementation) + attn_implementation = self._choose_llama3_attn_implementation(kwargs) + super().__init__(attn_implementation=attn_implementation, **kwargs) + self.head_dim = ( + head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + ) + + # Delete per-layer attributes after parent init (they should only exist in block_configs) + self._delete_per_layer_attributes() + + if self.block_configs is not None: + assert len(self.block_configs) == self.num_hidden_layers + + def _delete_per_layer_attributes(self): + """Delete per-layer attributes that should only exist in block_configs. 
+ + These attributes are intentionally deleted AFTER super().__init__() to ensure + they don't exist at the global config level. Deleting them (rather than setting + to None) makes it clear they shouldn't be accessed globally. + """ + present_attrs = { + attr: getattr(self, attr) + for attr in self.PER_LAYER_ATTRIBUTE_MAPPING + if hasattr(self, attr) + } + if present_attrs: + warnings.warn( + f"Deleting global per-layer attributes (should only be in block_configs): {present_attrs}", + UserWarning, + stacklevel=3, + ) + for attr in self.PER_LAYER_ATTRIBUTE_MAPPING: + if hasattr(self, attr): + delattr(self, attr) + + def _choose_llama4_attn_implementation(self, llama4_attn_implementation): + self.llama4_attn_implementation = llama4_attn_implementation + if self.llama4_attn_implementation is None: + if is_torch_sdpa_available(): + _print_once("auto-setting llama4_attn_implementation to sdpa") + self.llama4_attn_implementation = "sdpa" + else: + _print_once("auto-setting llama4_attn_implementation to eager") + self.llama4_attn_implementation = "eager" + + def _choose_llama3_attn_implementation(self, kwargs: dict[str, Any]) -> str: + attn_implementation = kwargs.pop("attn_implementation", None) + if attn_implementation is None and is_flash_attn_2_available(): + _print_once("auto-setting attn_implementation (for Llama3 layers) to flash_attention_2") + attn_implementation = "flash_attention_2" + + if self.block_configs is not None: + using_unshifted_sink = any( + block_config.attention.unshifted_sink for block_config in self.block_configs + ) + if using_unshifted_sink and attn_implementation != "eager": + warnings.warn( + "Forcing attn_implementation='eager' since some attention layers use unshifted sink" + ) + attn_implementation = "eager" + return attn_implementation + + def to_dict(self) -> dict[str, Any]: + """Convert config to dictionary, removing per-layer-only attributes.""" + self_dict = super().to_dict() + if self.block_configs is not None: + self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs] + + # Remove global keys that should only exist per-layer in block_configs + for key in self.PER_LAYER_ATTRIBUTE_MAPPING: + self_dict.pop(key, None) + + return self_dict + + def set_block_configs(self, block_configs: list[BlockConfig]) -> "DeciLMConfig": + new_model_config = copy.deepcopy(self) + new_model_config.block_configs = block_configs + new_model_config.num_hidden_layers = len(block_configs) + return new_model_config + + def get_num_hidden_layers(self) -> int: + return self.num_hidden_layers + + def get_hidden_size(self) -> int: + return self.hidden_size + + def get_embedding_layer_name(self) -> str: + return "model.embed_tokens" + + def get_final_layer_norm_layer_name(self) -> str: + return "model.norm" + + def get_lm_head_layer_name(self) -> str: + return "lm_head" + + def get_layers_layer_name(self) -> str: + return "model.layers" + + def get_block_config(self, layer_idx: int | tuple[int, ...]) -> BlockConfig: + if isinstance(layer_idx, tuple) and len(layer_idx) == 1: + layer_idx = layer_idx[0] + + if isinstance(layer_idx, int): + return self.block_configs[layer_idx] + + external_layer_idx, internal_layer_idx = layer_idx + return self.block_configs[external_layer_idx].parallel_blocks[internal_layer_idx] + + def get_min_attention_chunk_size(self) -> int | None: + min_chunk_size = float("inf") + for block_config in self.block_configs: + if block_config.attention.llama4 is not None: + attention_chunk_size = 
block_config.attention.llama4.attention_chunk_size + if attention_chunk_size is not None: + min_chunk_size = min(min_chunk_size, attention_chunk_size) + + if min_chunk_size == float("inf"): + return None + return min_chunk_size + + +def _print_once(message: str): + if not hasattr(_print_once, "was_printed"): + _print_once.was_printed = set() + if message not in _print_once.was_printed: + _print_once.was_printed.add(message) + print(message) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py new file mode 100644 index 000000000..e872a87d2 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py @@ -0,0 +1,2537 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa +# mypy: ignore-errors +import copy +import importlib.metadata +import json +import os +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + +import torch +from packaging import version +from transformers.configuration_utils import PretrainedConfig +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 +from transformers.utils import ( + is_hqq_available, + is_optimum_quanto_available, + is_torch_greater_or_equal, + logging, +) + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + +logger = logging.get_logger(__name__) + + +class Cache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + is_compileable = False + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: int | None = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: str | os.PathLike): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. 
+ """ + + def __init__( + self, + backend: str = "quanto", + nbits: int | None = 4, + axis_key: int | None = 0, + axis_value: int | None = 0, + q_group_size: int | None = 64, + residual_length: int | None = 128, + compute_dtype: torch.dtype | None = torch.float16, + device: str | None = "cpu", + ): + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), + ) + + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) + + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) + + +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() + ``` + """ + + def __init__(self, _distributed_cache_data: Iterable = None) -> None: + super().__init__() + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 + # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the + # iterable contains the key and value states for a layer gathered across replicas by torch.distributed + # (shape=[global batch size, num_heads, seq_len, head_dim]). + # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # compatibility. The name of the argument doesn't matter. + if _distributed_cache_data is not None: + for key_states, value_states in _distributed_cache_data: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[ + layer_idx + ].numel() # prefers not t.numel() to len(t) == 0 to export the model + ): # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) + <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. 
This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + key_cache = [ + current.key_cache[idx] for current in splits if current.key_cache[idx].numel() + ] + value_cache = [ + current.value_cache[idx] for current in splits if current.value_cache[idx].numel() + ] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +# Utilities for `DynamicCache` <> torch.export support +def _flatten_dynamic_cache( + dynamic_cache: DynamicCache, +): + """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" + if not isinstance(dynamic_cache, DynamicCache): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") + + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
+ ) + + # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking + dictionary = { + "key_cache": getattr(dynamic_cache, "key_cache"), + "value_cache": getattr(dynamic_cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): + dictionary = { + "key_cache": getattr(dynamic_cache, "key_cache"), + "value_cache": getattr(dynamic_cache, "value_cache"), + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, +): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + for k, v in dictionary.items(): + setattr(cache, k, v) + return cache + + +def _flatten_dynamic_cache_for_fx(cache, spec): + dictionary = { + "key_cache": getattr(cache, "key_cache"), + "value_cache": getattr(cache, "value_cache"), + } + return torch.utils._pytree.tree_flatten(dictionary)[0] + + +if is_torch_greater_or_equal("2.3"): + torch.utils._pytree.register_pytree_node( + DynamicCache, + _flatten_dynamic_cache, + _unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + ) + # TODO (tmanlaibaatar) This won't be needed in torch 2.7. + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default accelerator stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. 
+ """ + + def __init__(self) -> None: + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCache can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) + + super().__init__() + self.original_device = [] + self.prefetch_stream = None + self.prefetch_stream = ( + torch.Stream() + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.Stream() + ) + self.beam_idx = None # used to delay beam search operations + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to( + device, non_blocking=True + ) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + # Evict the previous layer if necessary + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. 
+ layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) < layer_idx: + raise ValueError( + "OffloadedCache does not support model usage where layers are skipped. Use DynamicCache." + ) + elif len(self.key_cache) == layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + +class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` + """ + + def __init__(self, cache_config: QuantizedCacheConfig) -> None: + super().__init__() + self._quantized_key_cache: list[torch.Tensor] = [] + self._quantized_value_cache: list[torch.Tensor] = [] + + self.nbits = cache_config.nbits + self.residual_length = cache_config.residual_length + self.q_group_size = cache_config.q_group_size + self.axis_key = cache_config.axis_key + self.axis_value = cache_config.axis_value + self.compute_dtype = cache_config.compute_dtype + self.device = cache_config.device + + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) < layer_idx: + raise ValueError( + "QuantizedCache does not support model usage where layers are skipped. Use DynamicCache." 
+ ) + elif len(self.key_cache) == layer_idx: + self._quantized_key_cache.append( + self._quantize(key_states.contiguous(), axis=self.axis_key) + ) + self._quantized_value_cache.append( + self._quantize(value_states.contiguous(), axis=self.axis_value) + ) + self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) + self.value_cache.append( + torch.zeros(0, dtype=key_states.dtype, device=key_states.device) + ) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = self._quantize( + keys_to_return.contiguous(), axis=self.axis_key + ) + self._quantized_value_cache[layer_idx] = self._quantize( + values_to_return.contiguous(), axis=self.axis_value + ) + self.key_cache[layer_idx] = torch.zeros( + 0, dtype=key_states.dtype, device=key_states.device + ) + self.value_cache[layer_idx] = torch.zeros( + 0, dtype=key_states.dtype, device=key_states.device + ) + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is + # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def _quantize(self, tensor, axis): + """Quantizes a key/value using a defined quantization method.""" + raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") + + def _dequantize(self, q_tensor): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") + + +class QuantoQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
+ + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 + + if self.nbits not in [2, 4]: + raise ValueError( + f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}" + ) + + if self.axis_key not in [0, -1]: + raise ValueError( + f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}" + ) + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = ( + MaxOptimizer() + ) # hardcode as it's the only one for per-channel quantization + + def _quantize(self, tensor, axis): + # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedCache(QuantizedCache): + """ + Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + + Parameters: + cache_config (`QuantizedCacheConfig`): + A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
+ + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() + ``` + """ + + def __init__(self, cache_config: CacheConfig) -> None: + super().__init__(cache_config) + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError( + f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}" + ) + + if self.axis_value not in [0, 1]: + raise ValueError( + f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}" + ) + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda( + qtensor, meta=meta, device=self.device + ) # Move to device and cast to dtype + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SinkCache() + ``` + """ + + is_sliding = True + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 
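+        When the window is full, the oldest non-sink tokens are dropped and, on RoPE models, the
+        remaining cached keys are re-rotated to match their shifted positions before the new tokens
+        are appended.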
+ + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + if cache_kwargs is None: + cache_kwargs = {} + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + elif self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] + elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length], + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb( + keys_to_keep, rerotation_cos, rerotation_sin + ) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat( + [sink_values, values_to_keep, value_states], dim=-2 + ) + + 
return self.key_cache[layer_idx], self.value_cache[layer_idx] + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the + number of beams if you are running beam search + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self._dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. 
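+        # Shape illustration (example numbers, not taken from any particular config): with
+        # max_batch_size=1, num_key_value_heads=8, max_cache_len=1024 and head_dim=128, each layer
+        # below pre-allocates a zero-filled (1, 8, 1024, 128) key tensor and value tensor that
+        # `update` later fills in place at the positions given by `cache_position`.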
+ cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + device = torch.device(device) if device is not None else None + for idx in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[idx] + else: + layer_device = device + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. 
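+        # For example, if positions 0..14 have been written and no cached key vector happens to be
+        # all zeros, `key_cache[layer_idx][0, 0].any(dim=-1)` flags 15 positions and `.sum()` returns 15.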
+ # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(StaticCache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ + + is_sliding = True + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + max_cache_len = min(config.sliding_window, max_cache_len) + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) + if cache_position.shape[0] > self.max_cache_len: + k_out = key_states[:, :, -self.max_cache_len :, :] + v_out = value_states[:, :, -self.max_cache_len :, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones( + self.max_cache_len, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS 
device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() + ``` + + """ + + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache.key_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. 
+ """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), + self.cross_attention_cache.to_legacy_cache(), + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx) + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr( + self.cross_attention_cache, "reset" + ): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. 
This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + for idx in range(len(splits[0])): + layer_keys = torch.cat( + [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + self_attention_cache.update(layer_keys, layer_values, idx) + + layer_keys = torch.cat( + [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 + ) + layer_values = torch.cat( + [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 + ) + cross_attention_cache.update(layer_keys, layer_values, idx) + return cls(self_attention_cache, cross_attention_cache) + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. 
+ dtype (torch.dtype, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. + # is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self._dtype = dtype + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + + layer_switch = ( + config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 + ) # 2 is for BC + self.is_sliding = torch.tensor( + [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], + dtype=torch.bool, + ) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + global_cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + device = torch.device(device) if device is not None and isinstance(device, str) else None + for i in range(config.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[i] + else: + layer_device = device + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + 
self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + sliding_window = cache_kwargs.get("sliding_window") + + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # when the cache is initialized in the forward pass (e.g. Gemma2) + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + if self.value_cache[layer_idx].device != value_states.device: + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def get_seq_length(self, layer_idx: int | None = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class HybridChunkedCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. 
This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.bfloat16, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + else: + self.sliding_window = config.sliding_window + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._dtype = dtype + + if hasattr(config.get_text_config(), "no_rope_layers"): + self.is_sliding = config.no_rope_layers + else: + layer_switch = getattr(config, "sliding_window_pattern", 2) + self.is_sliding = [ + bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers) + ] + + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + + def initialise_cache_layer(self, layer_idx, key_states): + if len(self.key_cache) > layer_idx: + return + + num_key_value_heads = key_states.shape[1] + device = key_states.device + global_cache_shape = ( + self.max_batch_size, + num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + self.max_batch_size, + num_key_value_heads, + self.sliding_window, + self.head_dim, + ) + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. 
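+        # Unlike `HybridCache`, layers here are allocated lazily on their first `update` call:
+        # chunked/sliding layers get a (batch, kv_heads, sliding_window, head_dim) buffer, global
+        # layers a (batch, kv_heads, max_cache_len, head_dim) one, with the number of KV heads taken
+        # from the incoming `key_states` rather than from the config.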
+ cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + cumulative_length = self.cumulative_length[layer_idx] + # Update it now that we saved the value above + self.cumulative_length[layer_idx] += key_states.shape[-2] + is_full = cumulative_length >= max_cache_len + if is_full: + full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed adress + # in memory (the values are the same as the full states, but not the address!!) + if key_states.shape[-2] == 1: + self.key_cache[layer_idx].copy_(full_key_states) + self.value_cache[layer_idx].copy_(full_value_states) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: + # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat( + (k_out[:, :, :cumulative_length, :], key_states), dim=-2 + ) + full_value_states = torch.cat( + (v_out[:, :, :cumulative_length, :], value_states), dim=-2 + ) + else: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) + self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return full_key_states, full_value_states + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + self.initialise_cache_layer(layer_idx, key_states) + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if self.is_sliding[layer_idx]: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + 
k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def get_seq_length(self, layer_idx: int | None = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." + ) + if len(self.key_cache) == 0: + return 0 + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + is_compileable = True + + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: torch.device | str | None = None, + ): + self.max_batch_size = max_batch_size + self._dtype = dtype + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: list[torch.Tensor] = [] + self.ssm_states: list[torch.Tensor] = [] + device = torch.device(device) if device is not None else None + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=self._dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=self._dtype, + ) + + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) + + def 
update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) + + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to( + device=conv_state.device, dtype=conv_state.dtype + ) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) + return self.ssm_states[layer_idx] + + def reset(self): + for layer_idx in range(len(self.conv_states)): + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +class OffloadedStaticCache(StaticCache): + """ + Static cache class to be used with `torch.compile(model)` that offloads to the CPU or + another device. + + Args: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize + the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`Union[str, torch.device]`): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*): + The default `dtype` to use when initializing the cache. + offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): + The device to offload to. Defaults to CPU. + layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: int | None, + device: str | torch.device, + dtype: torch.dtype | None = None, + offload_device: str | torch.device = torch.device("cpu"), + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super(Cache, self).__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + self.device = ( + torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) + ) + self.offload_device = torch.device(offload_device) + self._dtype = dtype if dtype is not None else torch.float32 + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + + cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) + + # Create offloaded CPU tensors. + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + for i in range(config.num_hidden_layers): + # First layer is always on-device. + device = self.device if i == 0 else self.offload_device + + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) + + self.key_cache.append(key_cache) + self.value_cache.append(value_cache) + + # Create device tensors. + self._device_key_cache: list[torch.Tensor] = [] + self._device_value_cache: list[torch.Tensor] = [] + + for i in range(2): + key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) + + self._device_key_cache.append(key_cache) + self._device_value_cache.append(value_cache) + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Create new CUDA stream for parallel prefetching. + self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. 
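+        For layers other than the first, the update targets a small on-device staging buffer while the
+        full cache for that layer lives on the offload device, and (on CUDA) the next layer is
+        prefetched on a separate stream in parallel.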
+ + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the + `cache_position` input to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + if layer_idx == 0: + # Update seen tokens. + # TODO(gante): Remove this. + self._seen_tokens += key_states.shape[-2] + + # Always there. + k_out = self.key_cache[0] + v_out = self.value_cache[0] + else: + # Wait for prefetch stream. + if self._prefetch_stream is not None: + torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) + + k_out = self._device_key_cache[layer_idx & 1] + v_out = self._device_value_cache[layer_idx & 1] + + self._prefetch_layer(layer_idx + 1) + + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + + # Copy the values to the offloaded device as well. + if layer_idx == 0: + self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) + self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does + # explicitly an in-place operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # Copy the values to the offloaded device as well. + if layer_idx != 0: + cache_position = cache_position.to(self.offload_device) + key_states = key_states.to(self.offload_device) + value_states = value_states.to(self.offload_device) + + try: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS + # device. + self.key_cache[layer_idx][:, :, cache_position] = key_states + self.value_cache[layer_idx][:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + + # TODO(gante): Remove this. + return self._seen_tokens + + def get_max_cache_shape(self) -> int | None: + """Returns the maximum sequence length of the cached states.""" + + return self.max_cache_len + + def reset(self) -> None: + """Resets the cache values while preserving the objects.""" + + # For backwards compatibility. + # TODO(gante): Remove this. + self._seen_tokens = 0 + + # Zero out cache. + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address. + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + @property + def seen_tokens(self) -> int: + # For backwards compatibility. + # TODO(gante): Remove this. 
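+        # Kept so callers that still read `past_key_values.seen_tokens` keep working; for this class
+        # `get_seq_length()` returns the same counter.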
+ return self._seen_tokens + + def _create_key_value_cache_tensors( + self, shape: tuple[int, ...], device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: + """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static + addresses for non-CPU tensors. + + Args: + shape (`Tuple[int, ...]`): Shape. + device (`torch.device`): Device. + + Returns: + Key and value cache tensors as a tuple. + """ + + is_cpu_device = device == torch.device("cpu") + + key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(key_cache) + torch._dynamo.mark_static_address(value_cache) + + return key_cache, value_cache + + def _prefetch_layer(self, layer_idx: int) -> None: + """Prefetch a layer to the device. Needs to be called in order of layer indices.""" + + # Don't fetch layers that do not exist. + if layer_idx >= len(self.key_cache): + return + + # Alternate between two on-device caches. + if self._prefetch_stream is not None: + with torch.cuda.stream(self._prefetch_stream): + self._prefetch_layer_in_context(layer_idx) + else: + self._prefetch_layer_in_context(layer_idx) + + def _prefetch_layer_in_context(self, layer_idx: int) -> None: + """Performs the actual copy of the layer to device cache.""" + + self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) + self._device_value_cache[layer_idx & 1].copy_( + self.value_cache[layer_idx], non_blocking=True + ) From 991659f37b6a038befde8fc8ac5f69fbf64ad086 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:18:13 +0100 Subject: [PATCH 02/49] Add decilm modelling code. Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 2 + .../transformers_4_44_2__cache_utils.py | 1447 +++++++++++++++++ .../transformers_4_51_3__cache_utils.py | 1 - 3 files changed, 1449 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4cf8ddd34..fe2c3fd95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,12 +28,14 @@ repos: exclude: > (?x)^( modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils\.py| modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py )$ - id: ruff-format exclude: > (?x)^( modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils\.py| modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py )$ diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py new file mode 100644 index 000000000..83d7251dd --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils.py @@ -0,0 +1,1447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torchdynamo_compiling, logging + + +logger = logging.get_logger(__name__) + + +class Cache(torch.nn.Module): + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +@dataclass +class CacheConfig: + """ + Base class for cache configs + """ + + cache_implementation: None + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. 
Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__(self) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. 
This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +class OffloadedCache(DynamicCache): + """ + A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. + Useful for generating from models with very long context. + + In addition to the default CUDA stream, where all forward() computations happen, + this class uses another stream, the prefetch stream, which it creates itself. + Since scheduling of operations on separate streams happens independently, this class uses + the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. + The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to + ensure the eviction is scheduled after all computations on that cache are finished. 
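Unlike the other cache classes in this file, `OffloadedCache` ships without a usage example. A minimal sketch in the style of the `DynamicCache` example above, using the upstream `transformers.cache_utils` class and assuming a CUDA device is available (the constructor raises a `RuntimeError` otherwise); the model name is only illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import OffloadedCache

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").cuda()
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer(text="My name is GPT2", return_tensors="pt").to("cuda")

# Drop-in replacement for DynamicCache: only the layers currently in use stay on GPU,
# the rest is offloaded to CPU and prefetched asynchronously.
past_key_values = OffloadedCache()
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```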
+ """ + + def __init__(self) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("OffloadedCache can only be used with a GPU") + super().__init__() + self.original_device = [] + self.prefetch_stream = torch.cuda.Stream() + self.beam_idx = None # used to delay beam search operations + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with torch.cuda.stream(self.prefetch_stream): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to( + device, non_blocking=True + ) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( + "cpu", non_blocking=True + ) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + # Evict the previous layer if necessary + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + original_device = self.original_device[layer_idx] + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + # Now deal with beam search ops which were delayed + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(original_device) + key_tensor = key_tensor.index_select(0, self.beam_idx) + value_tensor = value_tensor.index_select(0, self.beam_idx) + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + return (key_tensor, value_tensor) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Saves the beam indices and reorders the cache when the tensor is back to its device.""" + # We delay this operation until the tensors are back to their original + # device because performing torch.index_select on the CPU is very slow + del self.beam_idx + self.beam_idx = beam_idx.clone() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. 
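The prefetch and eviction described in the class docstring are driven by `__getitem__` above: reading layer `k` schedules a prefetch of layer `k + 1` on the side stream and moves layer `k - 1` back to the CPU. A small sketch of that access pattern with dummy tensors; shapes are illustrative and a CUDA device is assumed:

```python
import torch
from transformers.cache_utils import OffloadedCache

cache = OffloadedCache()
# Populate a few layers with dummy KV tensors of shape (batch, heads, seq_len, head_dim).
for layer_idx in range(4):
    k = torch.randn(1, 8, 5, 64, device="cuda")
    v = torch.randn(1, 8, 5, 64, device="cuda")
    cache.update(k, v, layer_idx)

# Reading a layer returns its (key, value) pair, prefetches the next layer,
# and evicts the previous one to CPU.
for layer_idx in range(len(cache)):
    key_tensor, value_tensor = cache[layer_idx]
```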
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + self.evict_previous_layer(layer_idx) + else: + key_tensor, value_tensor = self[layer_idx] + self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError + # if a method is not supposed to be supported in a subclass we should set it to None + from_legacy_cache = None + + to_legacy_cache = None + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + 
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.window_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the sin/cos cache, which holds sin/cos values for all possible positions + if using_rope and layer_idx == 0: + # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove + # after all RoPE models have a llama-like cache utilization. + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] 
+ elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length], + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb( + keys_to_keep, rerotation_cos, rerotation_sin + ) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat( + [sink_values, values_to_keep, value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. 
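The class summary above ties this cache to `torch.compile(model)`: every layer's K/V buffer is preallocated at a fixed shape and marked as a static address, so the decoding step can be compiled without shape-driven recompilation. A brief sketch of that pairing, complementing the example below; the model choice and compile flags are illustrative assumptions, not part of this file:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=64, device=model.device, dtype=model.dtype
)
# The fixed-shape, static-address buffers are what make this compile-friendly.
compiled_forward = torch.compile(model.forward, mode="reduce-overhead")
outputs = compiled_forward(**inputs, past_key_values=past_key_values, use_cache=True)
```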
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if max_cache_len is None else max_cache_len + ) + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + for idx in range(config.num_hidden_layers): + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + # Notes: + # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case + # it is not needed anyway) + # 2. `torch.export()` requires mutations to be registered as buffers. + if not is_torchdynamo_compiling(): + self.register_buffer( + f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) + ) + self.register_buffer( + f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device) + ) + new_layer_key_cache = getattr(self, f"key_cache_{idx}") + new_layer_value_cache = getattr(self, f"value_cache_{idx}") + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. 
+ + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + cache_position = cache_kwargs.get("cache_position") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + # return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def reset(self): + self._seen_tokens = 0 + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class SlidingWindowCache(StaticCache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. + Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. 
Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None + ) -> None: + super().__init__(config, max_batch_size, max_cache_len, device, dtype) + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." 
+ ) + max_cache_len = min(config.sliding_window, max_cache_len) + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + + # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) + if cache_position.shape[0] > self.max_cache_len: + k_out = key_states[:, :, -self.max_cache_len :, :] + v_out = value_states[:, :, -self.max_cache_len :, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones( + self.max_cache_len, dtype=torch.long, device=value_states.device + ).cumsum(0) + cache_position = cache_position.clamp(0, self.max_cache_len - 1) + to_shift = cache_position >= self.max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len + + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + try: + cache_position.to(device=k_out.device) + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + + return k_out, v_out + + def get_max_length(self) -> Optional[int]: + # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is + return None + + def reset(self): + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. 
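As the summary above says, this container simply pairs a self-attention cache with a cross-attention cache, and indexing a layer returns the four tensors for that layer. A small sketch with dummy tensors, complementing the Whisper example below; shapes are illustrative:

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()

# One layer of dummy KV tensors, shape (batch, heads, seq_len, head_dim).
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
self_attention_cache.update(k, v, layer_idx=0)
cross_attention_cache.update(k, v, layer_idx=0)

cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
# Per layer: (self-attn key, self-attn value, cross-attn key, cross-attn value).
self_k, self_v, cross_k, cross_v = cache[0]
```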
+ + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache() + >>> cross_attention_cache = DynamicCache() + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + + """ + + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache.key_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.key_cache[layer_idx], + self.self_attention_cache.value_cache[layer_idx], + self.cross_attention_cache.key_cache[layer_idx], + self.cross_attention_cache.value_cache[layer_idx], + ) + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), + self.cross_attention_cache.to_legacy_cache(), + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.self_attention_cache.key_cache) <= layer_idx: + return 0 + return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr( + self.cross_attention_cache, "reset" + ): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. 
This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        self_attention_cache = DynamicCache()
+        cross_attention_cache = DynamicCache()
+        for idx in range(len(splits[0])):
+            layer_keys = torch.cat(
+                [current.self_attention_cache.key_cache[idx] for current in splits], dim=0
+            )
+            layer_values = torch.cat(
+                [current.self_attention_cache.value_cache[idx] for current in splits], dim=0
+            )
+            self_attention_cache.update(layer_keys, layer_values, idx)
+
+            layer_keys = torch.cat(
+                [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0
+            )
+            layer_values = torch.cat(
+                [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0
+            )
+            cross_attention_cache.update(layer_keys, layer_values, idx)
+        return cls(self_attention_cache, cross_attention_cache)
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+        self.self_attention_cache.batch_repeat_interleave(repeats)
+        self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_select_indices.__name__)
+        self.self_attention_cache.batch_select_indices(indices)
+        self.cross_attention_cache.batch_select_indices(indices)
+
+
+class HybridCache(Cache):
+    """
+    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
+    and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
+    and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`, *optional*, defaults to `"cpu"`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
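The constructor below alternates cache geometry per layer: even-indexed layers are treated as sliding-window layers and get buffers of length `min(sliding_window, max_cache_len)`, while odd-indexed layers get global buffers of length `max_cache_len`. A tiny standalone sketch of that selection rule, mirroring the `is_sliding` logic in `__init__`; the values are illustrative:

```python
num_hidden_layers = 4
sliding_window = 4096
max_cache_len = 8192

# Mirrors `is_sliding` in `__init__`: even-indexed layers use the sliding window.
is_sliding = [not bool(i % 2) for i in range(num_hidden_layers)]
per_layer_len = [
    min(sliding_window, max_cache_len) if sliding else max_cache_len for sliding in is_sliding
]
print(per_layer_len)  # [4096, 8192, 4096, 8192]
```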
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + raise ValueError( + "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // config.num_attention_heads + ) + + self.dtype = dtype if dtype is not None else torch.float32 + self.num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) + self.is_sliding = torch.tensor( + [not bool(i % 2) for i in range(config.num_hidden_layers)], + dtype=torch.bool, + device=device, + ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + global_cache_shape = ( + max_batch_size, + self.num_key_value_heads, + max_cache_len, + self.head_dim, + ) + sliding_cache_shape = ( + max_batch_size, + self.num_key_value_heads, + min(config.sliding_window, max_cache_len), + self.head_dim, + ) + for i in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. 
+ cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + if cache_position.shape[0] > max_cache_len: + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out + + def _static_update( + self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len + ): + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") + sliding_window = cache_kwargs.get("sliding_window") + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + if sliding_window: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_length(self) -> Optional[int]: + # in theory there is no limit because the sliding window size is fixed + # no matter how long the sentence is + return self.max_cache_len + + def get_seq_length(self, layer_idx: Optional[int] = 0): + return None + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. 
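`update_conv_state` later in this class keeps a fixed-width rolling window of the most recent inputs for the causal convolution: it rolls the buffer left by one column and writes the newest column at the clamped cache position. A tiny standalone sketch of that buffer update; shapes are illustrative, and the real method writes at the clamped `cache_position`, which during decoding lands in the last slot:

```python
import torch

conv_kernel_size = 4
conv_state = torch.zeros(1, 16, conv_kernel_size)  # (batch, intermediate_size, conv_kernel_size)
new_column = torch.randn(1, 16)                     # newest input column

conv_state = conv_state.roll(shifts=-1, dims=-1)    # drop the oldest column on the left
conv_state[:, :, -1] = new_column                   # write the newest column in the freed slot
```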
+ + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + dtype (*optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Attributes: + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + intermediate_size: (`int`): + Model's intermediate_size taken from config. + ssm_state_size: (`int`): + Model's state_size taken from config. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config + conv_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv = outputs.past_key_values + ``` + """ + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + dtype: torch.dtype = torch.float16, + device: Optional[str] = None, + **kwargs, + ): + self.dtype = dtype + self.max_batch_size = max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=dtype, + ) + + torch._dynamo.mark_static_address(self.conv_states) + torch._dynamo.mark_static_address(self.ssm_states) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py index e872a87d2..ebcebdebe 100644 --- 
a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# ruff: noqa # mypy: ignore-errors import copy import importlib.metadata From 8489cee55dcebab303c0cd4f436e511cd8c6b4a6 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:38:55 +0100 Subject: [PATCH 03/49] Add transformers codebase Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 6 +- .../transformers_4_44_2__activations.py | 254 +++++++++ ...ransformers_4_44_2__configuration_llama.py | 218 ++++++++ ...ormers_4_44_2__modeling_attn_mask_utils.py | 497 ++++++++++++++++++ 4 files changed, 971 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe2c3fd95..8f432dc98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,15 +28,13 @@ repos: exclude: > (?x)^( modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py )$ - id: ruff-format exclude: > (?x)^( modelopt/torch/_compress/decilm/deci_lm_hf_code/block_config\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__cache_utils\.py| - modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils\.py + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py )$ - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py new file mode 100644 index 000000000..8b4810c5d --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import OrderedDict + +import torch +from packaging import version +from torch import Tensor, nn + +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class PytorchGELUTanh(nn.Module): + """ + A fast C implementation of the tanh approximation of the GeLU activation function. See + https://arxiv.org/abs/1606.08415. + + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.12.0"): + raise ImportError( + f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " + "PytorchGELUTanh. Please upgrade torch." + ) + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.gelu(input, approximate="tanh") + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. 
Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class AccurateGELUActivation(nn.Module): + """ + Applies GELU approximation that is faster than default and more accurate than QuickGELU. See: + https://github.com/hendrycks/GELUs + + Implemented along with MEGA (Moving Average Equipped Gated Attention) + """ + + def __init__(self): + super().__init__() + self.precomputed_constant = math.sqrt(2 / math.pi) + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3)))) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.9.0"): + self.act = self._mish_python + else: + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + return input + + +class LaplaceActivation(nn.Module): + """ + Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. 
See + https://arxiv.org/abs/2209.10655 + + Inspired by squared relu, but with bounded range and gradient for better stability + """ + + def forward(self, input, mu=0.707107, sigma=0.282095): + input = (input - mu).div(sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + torch.erf(input)) + + +class ReLUSquaredActivation(nn.Module): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward(self, input): + relu_applied = nn.functional.relu(input) + squared = torch.square(relu_applied) + return squared + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "gelu_pytorch_tanh": PytorchGELUTanh, + "gelu_accurate": AccurateGELUActivation, + "laplace": LaplaceActivation, + "leaky_relu": nn.LeakyReLU, + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu2": ReLUSquaredActivation, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": nn.SiLU, + "swish": nn.SiLU, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py new file mode 100644 index 000000000..0098de411 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from .transformers_4_44_2__modeling_rope_utils import rope_config_validation + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. 
+ eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 
+ + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py new file mode 100644 index 000000000..76631f758 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). 
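+
+        Example (an illustrative sketch; the tiny sizes below are arbitrary and only show the returned layout):
+
+        ```python
+        >>> import torch
+        >>> converter = AttentionMaskConverter(is_causal=True)
+        >>> mask = converter.to_causal_4d(1, 3, key_value_length=3, dtype=torch.float32)
+        >>> mask.shape
+        torch.Size([1, 1, 3, 3])
+        ```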
+ """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + mask.masked_fill_(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.FloatTensor, + min_dtype: float, + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) + + return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + is_training: bool = False, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). 
+ """ + + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or + # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). + if ( + (is_training or not is_tracing) + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + return False + elif (is_training or not is_tracing) and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. 
+ """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). 
+ is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if ignore_causal_mask: + expanded_4d_mask = None + elif attention_mask is None: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + if attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + _, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(mask, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. 
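+    # When no position is masked out (the 2D mask is all ones) and we are not tracing, return None so that SDPA
+    # receives no `attn_mask` at all and can skip the masking work entirely.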
+ if not is_tracing and torch.all(mask == 1): + return None + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask From f0afefe8fcd5aae7c52499d8efb696e6e7efecd7 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:42:34 +0100 Subject: [PATCH 04/49] Add transformers code Signed-off-by: Daniel Korzekwa --- ...g_flash_attention_utils_backward_compat.py | 363 ++++ .../transformers_4_44_2__modeling_outputs.py | 1768 +++++++++++++++++ ...ransformers_4_44_2__modeling_rope_utils.py | 574 ++++++ .../transformers_4_44_2__pytorch_utils.py | 32 + ...ansformers_4_51_3__configuration_llama4.py | 448 +++++ ...rmers_4_51_3__modeling_llama4_attention.py | 289 +++ 6 files changed, 3474 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py new file mode 100644 index 000000000..245184a1c --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py @@ -0,0 +1,363 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# mypy: ignore-errors
+import inspect
+import os
+from typing import Optional, Tuple, Union
+
+
+import torch
+import torch.nn.functional as F
+
+from functools import lru_cache
+import importlib.metadata
+import importlib.util
+from packaging import version
+
+from transformers.utils import is_flash_attn_2_available
+
+
+if is_flash_attn_2_available():
+    try:
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+    except ImportError as exc:
+        # Raising a plain string is a TypeError in Python 3; re-raise a proper ImportError instead.
+        raise ImportError("Unable to import flash_attn") from exc
+
+
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check if the package spec exists and grab its version to avoid importing a local directory
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            # Primary method to get the package version
+            package_version = importlib.metadata.version(pkg_name)
+        except importlib.metadata.PackageNotFoundError:
+            # Fallback method: Only for "torch" and versions containing "dev"
+            if pkg_name == "torch":
+                try:
+                    package = importlib.import_module(pkg_name)
+                    temp_version = getattr(package, "__version__", "N/A")
+                    # Check if the version contains "dev"
+                    if "dev" in temp_version:
+                        package_version = temp_version
+                        package_exists = True
+                    else:
+                        package_exists = False
+                except ImportError:
+                    # If the package can't be imported, it's not available
+                    package_exists = False
+            else:
+                # For packages other than "torch", don't attempt the fallback and set as not available
+                package_exists = False
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+
+@lru_cache()
+def is_flash_attn_greater_or_equal(library_version: str):
+    if not _is_package_available("flash_attn"):
+        return False
+
+    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+    """
+    Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+    Arguments:
+        attention_mask (`torch.Tensor`):
+            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+    Return:
+        indices (`torch.Tensor`):
+            The indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens (`torch.Tensor`):
+            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        max_seqlen_in_batch (`int`):
+            Maximum sequence length in batch.
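+
+    Example (an illustrative sketch; a tiny right-padded batch chosen only to make the returned values concrete):
+
+    ```python
+    >>> import torch
+    >>> attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
+    >>> indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
+    >>> indices  # flattened positions of the valid tokens
+    tensor([0, 1, 3, 4, 5])
+    >>> cu_seqlens  # cumulative sequence lengths per batch entry
+    tensor([0, 2, 5], dtype=torch.int32)
+    >>> max_seqlen
+    3
+    ```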
+ """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+
+    Args:
+        query_states (`torch.Tensor`):
+            Input query states to be passed to Flash Attention API
+        key_states (`torch.Tensor`):
+            Input key states to be passed to Flash Attention API
+        value_states (`torch.Tensor`):
+            Input value states to be passed to Flash Attention API
+        attention_mask (`torch.Tensor`):
+            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+            position of padding tokens and 1 for the position of non-padding tokens.
+        dropout (`float`):
+            Attention dropout
+        softmax_scale (`float`, *optional*):
+            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
+        use_top_left_mask (`bool`, defaults to `False`):
+            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
+        softcap (`float`, *optional*):
+            Softcap for the attention logits, used e.g. in gemma2.
+        deterministic (`bool`, *optional*):
+            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+    """
+    if not use_top_left_mask:
+        causal = is_causal
+    else:
+        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+        causal = is_causal and query_length != 1
+
+    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
+    use_sliding_windows = (
+        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
+    )
+    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+
+    if is_flash_attn_greater_or_equal("2.4.1"):
+        if deterministic is None:
+            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+        flash_kwargs["deterministic"] = deterministic
+
+    if softcap is not None:
+        flash_kwargs["softcap"] = softcap
+
+    # Contains at least one padding token in the sequence
+    if attention_mask is not None:
+        batch_size = query_states.shape[0]
+        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+            query_states, key_states, value_states, attention_mask, query_length
+        )
+        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+        attn_output_unpad = flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_in_batch_q,
+            max_seqlen_k=max_seqlen_in_batch_k,
+            dropout_p=dropout,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            **flash_kwargs,
+        )
+        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
+    # If position_ids is provided, check whether the batch contains packed sequences: if position_ids are
+    # monotonically increasing, each example probably holds a single sequence; otherwise the examples are packed.
+    # For packed examples in the pre-fill/training stage (query_length != 1):
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py new file mode 100644 index 000000000..dddf0c9e6 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py @@ -0,0 +1,1768 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from transformers.utils import ModelOutput + + +@dataclass +class BaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. 
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. 
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router logits that are computed by MoE routers, these terms are used to compute the auxiliary
+            loss for Mixture of Experts models.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            aux_loss for the sparse modules.
+
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router logits that are computed by MoE routers, these terms are used to compute the auxiliary
+            loss for Mixture of Experts models.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as + Mixture of Expert's router hidden states terms, to train a MoE model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. 
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. 
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, + value states of the self-attention and the cross-attention layers if model is used in encoder-decoder + setting. Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. 
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts + models. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + encoder_z_loss: torch.FloatTensor = None + decoder_z_loss: torch.FloatTensor = None + encoder_aux_loss: torch.FloatTensor = None + decoder_aux_loss: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class NextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): + Next sequence prediction (classification) loss. + logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. 
(see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class QuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + predicted_depth: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageSuperResolutionOutput(ModelOutput): + """ + Base class for outputs of image super resolution models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed images, possibly upscaled. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Wav2Vec2BaseModelOutput(ModelOutput): + """ + Base class for models that have been trained with the Wav2Vec2 loss objective. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + extract_features: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ForXVector`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. 
+ logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BackboneOutput(ModelOutput): + """ + Base class for outputs of backbones. + + Args: + feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Feature maps of the stages. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, + depending on the backbone. + + Hidden-states of the model at the output of each stage plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Only applicable if the backbone uses attention. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + feature_maps: Tuple[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndProjection(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. + + Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + projection_state: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqSpectrogramOutput(ModelOutput): + """ + Base class for sequence-to-sequence spectrogram outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Spectrogram generation loss. + spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): + The predicted spectrogram. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. 
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + spectrogram: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqTSModelOutput(ModelOutput): + """ + Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up + sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. 
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class Seq2SeqTSPredictionOutput(ModelOutput): + """ + Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the + chosen distribution. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): + Distributional loss. + params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): + Parameters of the chosen distribution. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + loss: Optional[torch.FloatTensor] = None + params: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class SampleTSPredictionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Args: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): + Sampled values from the chosen distribution. + """ + + sequences: torch.FloatTensor = None + + +@dataclass +class MaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. 
Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py new file mode 100644 index 000000000..0a59929aa --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py @@ -0,0 +1,574 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. 
+ rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_linear_scaling_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + factor = rope_kwargs["factor"] + elif config is not None: + factor = config.rope_scaling["factor"] + + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. 
+ seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + max_position_embeddings = rope_kwargs["max_position_embeddings"] + factor = rope_kwargs["factor"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. 
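+    # Illustrative numbers only (not used by the code below): with factor=8 and
+    # the default beta_fast=32 / beta_slow=1, the attention factor evaluates to
+    # 0.1 * ln(8) + 1 ~= 1.21, and the linear ramp blends the interpolated
+    # frequencies (inv_freq / factor) with the extrapolated ones (inv_freq)
+    # between the two correction dimensions.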
+ if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. 
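+
+    Note (summary of the behaviour implemented below): for configs that define
+    `original_max_position_embeddings` (e.g. Phi3-style), the scaling factor is
+    derived as `max_position_embeddings / original_max_position_embeddings`
+    instead of being read from `rope_scaling["factor"]`.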
+ Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + # No need to keep BC with longrope, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " + f"{rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + long_factor = config.rope_scaling["long_factor"] + short_factor = config.rope_scaling["short_factor"] + factor = config.rope_scaling.get("factor") + attention_factor = config.rope_scaling.get("attention_factor") + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if hasattr(config, "original_max_position_embeddings"): + max_position_embeddings = config.original_max_position_embeddings + expanded_max_position_embeddings = config.max_position_embeddings + factor = expanded_max_position_embeddings / max_position_embeddings + else: + max_position_embeddings = config.max_position_embeddings + expanded_max_position_embeddings = max_position_embeddings * factor + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if expanded_max_position_embeddings > max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. 
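+
+    Example (illustrative, using the defaults noted in the body below: factor=8,
+    low_freq_factor=1, high_freq_factor=4, original_max_position_embeddings=8192):
+    wavelengths longer than 8192 positions have their inverse frequency divided
+    by 8, wavelengths shorter than 2048 are left unchanged, and the band in
+    between is smoothly interpolated between the two.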
+ """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's gracefully handle it + if "rope_type" not in received_keys and "type" in received_keys: + received_keys -= {"type"} + received_keys.add("rope_type") + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, 
required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_yarn_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +def _validate_longrope_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys) + + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got 
{short_factor}") + if not len(short_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") + + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if not len(long_factor) == dim // 2: + logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") + + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + +def _validate_llama3_parameters(config: PretrainedConfig): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + 
logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) + + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "linear": _validate_linear_scaling_rope_parameters, + "dynamic": _validate_dynamic_scaling_rope_parameters, + "yarn": _validate_yarn_parameters, + "longrope": _validate_longrope_parameters, + "llama3": _validate_llama3_parameters, +} + + +def rope_config_validation(config: PretrainedConfig): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py new file mode 100644 index 000000000..1057f847e --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] \ No newline at end of file diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py new file mode 100644 index 000000000..f34205a39 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py @@ -0,0 +1,448 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Llama4VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a + Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + num_hidden_layers (`int`, *optional*, defaults to 34): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + vision_output_dim (`int`, *optional*, defaults to 7680): + Dimensionality of the vision model output. Includes output of transformer + encoder with intermediate layers and global transformer encoder. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image *tile*. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. 
+ vision_feature_layer (``, *optional*, defaults to -1): TODO + vision_feature_select_strategy (`int`, *optional*, defaults to `"default"`): TODO + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pixel_shuffle_ratio (`int`, *optional*, defaults to 0.5): TODO + projector_input_dim (`int`, *optional*, defaults to 4096): TODO + projector_output_dim (`int`, *optional*, defaults to 4096): TODO + multi_modal_projector_bias (`int`, *optional*, defaults to `False`): TODO + projector_dropout (`int`, *optional*, defaults to 0.0): TODO + attention_dropout (`int`, *optional*, defaults to 0.0): TODO + rope_theta (`int`, *optional*, defaults to 10000): TODO + """ + + base_model_tp_plan = { + "model.layers.*.self_attn.q_proj": "colwise", + "model.layers.*.self_attn.k_proj": "colwise", + "model.layers.*.self_attn.v_proj": "colwise", + "model.layers.*.self_attn.o_proj": "rowwise", + "vision_adapter.mlp.fc1": "colwise", + "vision_adapter.mlp.fc2": "rowwise", + "patch_embedding.linear": "colwise_rep", + } + model_type = "llama4_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size: int = 768, + hidden_act: str = "gelu", + num_hidden_layers: int = 34, + num_attention_heads: int = 16, + num_channels: int = 3, + intermediate_size: int = 5632, + vision_output_dim: int = 7680, + image_size: int = 448, + patch_size: int = 14, + norm_eps: float = 1e-5, + vision_feature_layer=-1, + vision_feature_select_strategy="default", + initializer_range: float = 0.02, + pixel_shuffle_ratio=0.5, + projector_input_dim=4096, + projector_output_dim=4096, + multi_modal_projector_bias=False, + projector_dropout=0.0, + attention_dropout=0.0, + rope_theta=10000, + **kwargs, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.num_channels = num_channels + self.intermediate_size = intermediate_size + self.image_size = image_size + self.vision_output_dim = vision_output_dim + self.patch_size = patch_size + self.norm_eps = norm_eps + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.pixel_shuffle_ratio = pixel_shuffle_ratio + self.projector_input_dim = projector_input_dim + self.projector_output_dim = projector_output_dim + self.multi_modal_projector_bias = multi_modal_projector_bias + self.projector_dropout = projector_dropout + self.attention_dropout = attention_dropout + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + self.rope_theta = rope_theta + super().__init__(**kwargs) + + +class Llama4TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a + Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 202048): + Vocabulary size of the Llama4 text model. 
Defines the maximum number of different tokens that can be represented + by the `inputs_ids` passed when calling [`Llama4TextModel`]. + hidden_size (`int`, *optional*, defaults to 5120): + Dimensionality of the embeddings and hidden states. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 40): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If not + specified, will default to `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 128): TODO + hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. + pad_token_id (`int`, *optional*, defaults to 128004): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the beginning of sentence token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to `500000.0`): + The base period of the RoPE embeddings. + attention_dropout (`int`, *optional*, defaults to 0.0): TODO + num_experts_per_tok (`int`, *optional*, defaults to 1): TODO + num_local_experts (`int`, *optional*, defaults to 16): TODO + moe_layers (`int`, *optional*): TODO + interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO + use_qk_norm (`int`, *optional*, defaults to `True`): TODO + output_router_logits (`int`, *optional*, defaults to `False`): TODO + router_aux_loss_coef (`int`, *optional*, defaults to 0.001): TODO + router_jitter_noise (`int`, *optional*, defaults to 0.0): TODO + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. 
+ `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + + no_rope_layers (`int`, *optional*): TODO + no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO + attention_chunk_size (`int`, *optional*, defaults to 8192): + + attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO + floor_scale (`int`, *optional*, defaults to 8192): TODO + attn_scale (`int`, *optional*, defaults to 0.1): TODO + cache_implementation (``, *optional*, defaults to `"hybrid"`): + + Example: + """ + + model_type = "llama4_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.input_layernorm.weight": "sequence_parallel", + "layers.*.post_attention_layernorm.weight": "sequence_parallel", + "norm.weight": "sequence_parallel", + "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise", + "layers.*.feed_forward.shared_expert.up_proj": "local_colwise", + "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise", + "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear + "layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear + "layers.*.feed_forward.experts": "local", + "layers.*.feed_forward.gate_proj": "local_colwise", + "layers.*.feed_forward.up_proj": "local_colwise", + "layers.*.feed_forward.down_proj": "local_rowwise", + "layers.*.feed_forward": "gather", + } + + def __init__( + self, + vocab_size=202048, + hidden_size=5120, + intermediate_size=8192, + intermediate_size_mlp=16384, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + 
eos_token_id=2, + tie_word_embeddings=False, + rope_theta=500000, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=16, + moe_layers=None, + interleave_moe_layer_step=1, + use_qk_norm=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + rope_scaling=None, + no_rope_layers=None, + no_rope_layer_interval=4, + attention_chunk_size=8192, + attn_temperature_tuning=4, + floor_scale=8192, + attn_scale=0.1, + cache_implementation="hybrid", + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.attn_temperature_tuning = attn_temperature_tuning + self.attn_scale = attn_scale + self.floor_scale = floor_scale + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.intermediate_size_mlp = intermediate_size_mlp + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.rope_scaling = rope_scaling + self.attention_bias = False + self.cache_implementation = cache_implementation + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.use_qk_norm = use_qk_norm + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + default_no_rope_layers = [ + int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers) + ] + + # no_rope_layers == [] is invalid as we cannot have 0 layers + self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers + + self.interleave_moe_layer_step = interleave_moe_layer_step + self.moe_layers = ( + moe_layers + if moe_layers is not None + else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step)) + ) + self.attention_chunk_size = attention_chunk_size + + +class Llama4Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate an + Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Llama4 109B. + + e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vision_config (`Llama4VisionConfig`, *optional*): + The Llama4 Vision config. + text_config (`Llama4TextConfig`, *optional*): + The Llama4 Text config. + boi_token_index (`int`, *optional*, defaults to 200080): + The begin-of-image token index to wrap the image prompt. 
+ eoi_token_index (`int`, *optional*, defaults to 200081): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 200092): + The image token index to encode the image prompt. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + + ```python + >>> from transformers import Llama4Model, Llama4Config + + >>> # Initializing a Llama4 7B style configuration + >>> configuration = Llama4Config() + + >>> # Initializing a model from the Llama4 7B style configuration + >>> model = Llama4Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama4" + sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig} + base_model_tp_plan = { + "multi_modal_projector.linear_1": "colwise_rep", + } + + def __init__( + self, + vision_config=None, + text_config=None, + boi_token_index=200080, + eoi_token_index=200081, + image_token_index=200092, + tie_word_embeddings=False, + **kwargs, + ): + if vision_config is None: + self.vision_config = Llama4VisionConfig() + logger.info("vision_config is None, using default llama4 vision config") + elif isinstance(vision_config, dict): + self.vision_config = Llama4VisionConfig(**vision_config) + elif isinstance(vision_config, Llama4VisionConfig): + self.vision_config = vision_config + + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + if text_config is None: + self.text_config = Llama4TextConfig() + logger.info("text_config is None, using default llama4 text config") + elif isinstance(text_config, dict): + self.text_config = Llama4TextConfig(**text_config) + elif isinstance(text_config, Llama4TextConfig): + self.text_config = text_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"] diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py new file mode 100644 index 000000000..122e13447 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.processing_utils import Unpack +from transformers.utils import ( + is_torch_flex_attn_available, + logging, +) +from .transformers_4_51_3__configuration_llama4 import Llama4TextConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from transformers.integrations.flex_attention import make_flex_block_causal_mask + +logger = logging.get_logger(__name__) + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class Llama4TextRotaryEmbedding(nn.Module): + def __init__(self, config: Llama4TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + self.rope_type = "llama3" if config.rope_scaling is not None else "default" + + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE 
block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # print(f"{module.layer_idx=} {module.num_key_value_groups=}") + # print(f"{module.layer_idx=} {module.head_dim=}") + # print(f"{module.layer_idx=} {module.training=}") + # print(f"{scaling=}") + # print(f"{dropout=}") + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Llama4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Llama4TextConfig, layer_idx, use_rope: bool): # we added use_rope to not be dependent on the layer index + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + 
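+        # e.g. with the default Llama4TextConfig (40 attention heads, 8 KV heads),
+        # each KV head is shared by num_key_value_groups = 40 // 8 = 5 query heads.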
self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True + # self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers + self.use_rope = use_rope + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + device = query_states.device + attn_scales = ( + torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 + ).to(device) + attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # print(f"{self.layer_idx=} {cache_position=} {attention_mask=}") + # print(f"{self.layer_idx=} {query_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {key_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {value_states.flatten()[:10]=}") + # print(f"{self.layer_idx=} {kwargs=}") + # print(f"{self.layer_idx=} {attention_interface=}") + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights From b3ed5bc9c88a4c28e860125c22fe0d86442daecf Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:46:33 +0100 Subject: [PATCH 05/49] Add decilm modelling code Signed-off-by: Daniel Korzekwa --- .../decilm/deci_lm_hf_code/__init__.py | 15 ++ .../decilm/deci_lm_hf_code/vllm_yarn_utils.py | 210 ++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py new file mode 100644 index 000000000..4c8f86cdb --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/vllm_yarn_utils.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
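+
+# Illustrative usage sketch (not executed anywhere in this module); the class and
+# argument names are the ones defined below, while the concrete values and the
+# `positions`/`query`/`key` tensors are made up for the example:
+#
+#   rope = YaRNScalingRotaryEmbedding(
+#       head_size=128, rotary_dim=128, max_position_embeddings=4096, base=10000,
+#       is_neox_style=True, scaling_factor=4.0, dtype=torch.float32,
+#   )
+#   # positions: (num_tokens,) int64; query/key: (num_tokens, n_heads * head_size)
+#   query, key = rope.forward_native(positions, query, key)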
+ +import math + +import torch +import torch.nn as nn + + +def _apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: int | float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * 
math.pi))) / ( + 2 * math.log(base) + ) + + +def _yarn_find_correction_range( + low_rot: int, high_rot: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 +) -> tuple[int, int]: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, self.beta_slow, self.rotary_dim, self.base, self.max_position_embeddings + ) + # print(f"low: {low}, high: {high}") + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache From a700da55191d3d1cfc584299440dbeea6e241677 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 3 Nov 2025 21:57:35 +0100 Subject: [PATCH 06/49] Add decilm modelling code Signed-off-by: Daniel Korzekwa --- .../decilm/deci_lm_hf_code/variable_cache.py | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py new file mode 100644 index 000000000..9acc27eb9 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/variable_cache.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from copy import deepcopy +from typing import Any + +import torch +from transformers.cache_utils import ( + Cache, # used to let GenerationMixin know that we use a Cache object +) + +from .configuration_decilm import DeciLMConfig +from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2 +from .transformers_4_44_2__cache_utils import SinkCache, SlidingWindowCache, StaticCache +from .transformers_4_51_3__cache_utils import HybridChunkedCache + +LayerIndex = tuple[ + int, ... +] # supports both regular transformer blocks and parallel transformer multi-blocks + + +class VariableCache(Cache_4_44_2, Cache): + """ + A Cache object that supports a different Cache implementation for every layer, + including layers without any kv-cache. + Implemented using a list of Cache objects, each represents a "model" with 1 layer. + The default implementation for the layer caches is StaticCache. + The cache of each layer is allocated to the same gpu as the layer itself. + """ + + def __init__( + self, + *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions + config: DeciLMConfig, + batch_size: int | None = None, + max_cache_len: int | None = None, + dtype: torch.dtype = torch.get_default_dtype(), + max_batch_size: int | None = None, + **kwargs, + ) -> None: + Cache_4_44_2.__init__(self) + + self.config = deepcopy(config) + self.max_batch_size = batch_size or max_batch_size + self.batch_size = self.max_batch_size + self.max_cache_len = ( + config.max_position_embeddings if (max_cache_len is None) else max_cache_len + ) + self.dtype = dtype + + self.layer_caches: dict[LayerIndex, Cache_4_44_2] = {} + self.layer_devices: dict[LayerIndex, torch.device] = {} + + def __repr__(self): + return ( + f"VariableCache:\n" + f"==============\n" + f"max_batch_size={self.max_batch_size}\n" + f"batch_size={self.batch_size}\n" + f"max_cache_len={self.max_cache_len}\n" + f"dtype={self.dtype}\n" + f"layer_caches={self.layer_caches}\n" + f"layer_devices={self.layer_devices}\n" + f"==============\n" + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int | LayerIndex, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(layer_idx, int): + layer_idx = _int_to_layer_index(layer_idx) + + if layer_idx not in self.layer_caches: + self.layer_devices[layer_idx] = key_states.device + self._init_layer_cache(layer_idx) + + layer_cache = self.layer_caches[layer_idx] + assert layer_cache is not None, ( + f"Trying to update the cache of a cache-less layer: {layer_idx=}" + ) + + k_out, v_out = layer_cache.update( + key_states=key_states, value_states=value_states, layer_idx=0, cache_kwargs=cache_kwargs + ) + + input_seq_len = key_states.shape[2] # [batch_size, num_kv_heads, seq_len, hidden_size] + cache_seq_len = self.get_seq_length(layer_idx) + seq_len = 
max(input_seq_len, cache_seq_len) + + k_out = k_out[:, :, :seq_len, :] + v_out = v_out[:, :, :seq_len, :] + return k_out, v_out + + def _init_layer_cache(self, layer_idx: LayerIndex) -> None: + block_config = self.config.get_block_config(layer_idx) + attention_config = block_config.attention + + if attention_config.no_op or attention_config.replace_with_linear: + return None + + device = self.layer_devices[layer_idx] + assert device is not None, f"Trying to init layer cache for {layer_idx=} without device" + + config = deepcopy(self.config) + config.num_hidden_layers = 1 + config.num_key_value_heads = ( + self.config.num_attention_heads // attention_config.n_heads_in_group + ) + + if attention_config.is_llama4: + attention_chunk_size = attention_config.llama4.attention_chunk_size + is_chunked = attention_chunk_size is not None + config.no_rope_layers = [int(is_chunked)] + config.attention_chunk_size = ( + attention_chunk_size if is_chunked else config.get_min_attention_chunk_size() + ) + self.layer_caches[layer_idx] = HybridChunkedCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + dtype=self.dtype, + ) + return + + if attention_config.window_length is not None: + if not attention_config.is_sink: + config.sliding_window = attention_config.window_length + self.layer_caches[layer_idx] = SlidingWindowCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + device=device, + dtype=self.dtype, + ) + return + elif not attention_config.unshifted_sink: + self.layer_caches[layer_idx] = SinkCache( + window_length=attention_config.window_length, + num_sink_tokens=attention_config.num_sink_tokens, + ) + return + + self.layer_caches[layer_idx] = StaticCache( + config=config, + max_batch_size=self.max_batch_size, + max_cache_len=self.max_cache_len, + device=device, + dtype=self.dtype, + ) + + def _get_arbitrary_cache(self) -> Cache_4_44_2: + if len(self.layer_caches) == 0: + raise NoCacheFoundError() + layer_cache = next(iter(self.layer_caches.values())) + return layer_cache + + def get_seq_length(self, layer_idx: int | LayerIndex | None = 0) -> int: + """default 0 to match standard HF implementation""" + if (layer_idx is None) or ( + layer_idx == 0 and _int_to_layer_index(0) not in self.layer_caches + ): + try: + layer_cache = self._get_arbitrary_cache() + return layer_cache.get_seq_length() + except NoCacheFoundError: + return 0 + + if isinstance(layer_idx, int): + layer_idx = _int_to_layer_index(layer_idx) + + layer_cache = self.layer_caches[layer_idx] + return layer_cache.get_seq_length() + + def get_max_length(self) -> int | None: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + for layer_idx, layer_cache in self.layer_caches.items(): + if hasattr(layer_cache, "reset"): + layer_cache.reset() + else: + self.layer_caches[layer_idx] = None + self.layer_devices[layer_idx] = None + # self._init_layer_cache(layer_idx) + + +class NoCacheFoundError(Exception): + pass + + +def _int_to_layer_index(layer_idx: int) -> LayerIndex: + return (layer_idx,) From b59b679311c867a6411b3009a0c5da899b2ca5aa Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 08:51:33 +0100 Subject: [PATCH 07/49] Correct licence headers Signed-off-by: Daniel Korzekwa --- .pre-commit-config.yaml | 1 + .../transformers_4_44_2__activations.py | 12 ++++++------ 2 files changed, 7 insertions(+), 
6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f432dc98..eec84b2b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -107,6 +107,7 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| + modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_.*\.py| )$ # Default hook for Apache 2.0 in c/c++/cuda files diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py index 8b4810c5d..6c964dbfc 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__activations.py @@ -1,25 +1,25 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, From 1abdf3e2945cc5a24c9ee38da1ff31e5c30764b7 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 08:53:56 +0100 Subject: [PATCH 08/49] Correct licence headers Signed-off-by: Daniel Korzekwa --- ...ransformers_4_44_2__configuration_llama.py | 23 ++++++++++--------- ...ormers_4_44_2__modeling_attn_mask_utils.py | 13 ++++++----- ...g_flash_attention_utils_backward_compat.py | 14 +++++------ .../transformers_4_44_2__modeling_outputs.py | 12 +++++----- ...ransformers_4_44_2__modeling_rope_utils.py | 12 +++++----- .../transformers_4_44_2__pytorch_utils.py | 12 +++++----- .../transformers_4_51_3__cache_utils.py | 1 - ...ansformers_4_51_3__configuration_llama4.py | 15 ++++++------ ...rmers_4_51_3__modeling_llama4_attention.py | 14 +++++------ 9 files changed, 58 insertions(+), 58 deletions(-) diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py index 0098de411..461996f74 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__configuration_llama.py @@ -1,37 +1,38 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py index 76631f758..725780067 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py @@ -1,31 +1,32 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import List, Optional, Tuple, Union diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py index 245184a1c..9e9fb46ca 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py @@ -1,26 +1,26 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# coding=utf-8 +# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# coding=utf-8 -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py index dddf0c9e6..aa9f07b87 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py @@ -1,25 +1,25 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright 2020 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py index 0a59929aa..761c2b640 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_rope_utils.py @@ -1,25 +1,25 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py index 1057f847e..a1b413b0e 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py @@ -1,25 +1,25 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py index ebcebdebe..3dac4a51c 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # mypy: ignore-errors import copy import importlib.metadata diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py index f34205a39..7dc65a092 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__configuration_llama4.py @@ -1,27 +1,27 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# coding=utf-8 -# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. # +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -29,7 +29,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py index 122e13447..b17883628 100644 --- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py @@ -1,27 +1,27 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# coding=utf-8 -# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. # +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, From 66609b1fb0870d5d4e9ff05164dc3be51cb79e89 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 09:35:55 +0100 Subject: [PATCH 09/49] Add decilm code Signed-off-by: Daniel Korzekwa --- .../deci_lm_hf_code/tokenization_mistral.py | 374 ++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py new file mode 100644 index 000000000..e67674a09 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py @@ -0,0 +1,374 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Based on https://github.com/vllm-project/vllm/blob/739e03b3449a7f3b0a81ebc30b9555305d914e2d/vllm/transformers_utils/tokenizers/mistral.py +# mypy: ignore-errors + +import os +import re +import sys +from pathlib import Path +from shutil import copyfile +from typing import TYPE_CHECKING, Any + +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + +if TYPE_CHECKING: + from mistral_common.protocol.instruct.request import ChatCompletionRequest + +logger = logging.get_logger(__name__) + + +def _called_from_vllm() -> bool: + frame = sys._getframe(1) + while frame: + mod = frame.f_globals.get("__name__", "") + if mod == "vllm" or mod.startswith("vllm."): + return True + frame = frame.f_back + return False + + +class HFAdaptedMistralTokenizer(PreTrainedTokenizer): + """ + In order to save the tokenizer, do the following: + ``` + # from import HFAdaptedMistralTokenizer + # from mistral_common.tokens.tokenizers.base import SpecialTokens + HFAdaptedMistralTokenizer.register_for_auto_class("AutoTokenizer") + tokenizer = HFAdaptedMistralTokenizer("", chat_template="dummy") + tokenizer.add_special_tokens( + {"additional_special_tokens": [v.value for _, v in SpecialTokens.__members__.items()]} + ) + tokenizer.save_pretrained("") + ``` + """ + + vocab_files_names = {"path_indicator": "tokenizer_config.json"} + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + path_indicator: str, + unk_token: str | None = None, + bos_token: str | None = None, + eos_token: str | None = None, + pad_token: str | None = None, + add_bos_token: bool = True, + add_eos_token: bool = False, + clean_up_tokenization_spaces: bool = False, + **kwargs, + ): + path_indicator: Path = Path(path_indicator) + if path_indicator.name == "tokenizer_config.json": + path_indicator = path_indicator.parent + if path_indicator.is_dir(): + tokenizer_file_name = _find_tokenizer_file(os.listdir(path_indicator)) + tokenizer_file = str(path_indicator / tokenizer_file_name) + else: + tokenizer_file = path_indicator + self._mistral_tokenizer_path = str(tokenizer_file) + + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer as MistralTokenizer + + self._mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file) + self._instruct_tokenizer = self._mistral_tokenizer.instruct_tokenizer + + # Copied from https://github.com/patrickvonplaten/vllm/blob/6cca3d8c330e169bbf386561c441ca5f3879cf85/vllm/transformers_utils/tokenizers/mistral.py + self.version: int = int( + self._instruct_tokenizer.tokenizer.version.value.split("v")[-1].split("m")[0] + ) + + tokenizer_ = self._instruct_tokenizer.tokenizer + from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer + + self.is_tekken = isinstance(tokenizer_, Tekkenizer) + from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer + + self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) + if self.is_tekken: + # Make sure special tokens will not raise + tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE + elif self.is_spm: + pass + else: + raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + + self._vocab = tokenizer_.vocab() + # Convert to a Dict[str, int] to match protocol, but this is a lossy + # conversion. 
There may be multiple token ids that decode to the same + # string due to partial UTF-8 byte sequences being converted to � + self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} + self._tokenizer = tokenizer_ + self._max_token_id = self.vocab_size - 1 + self.vocab = self._vocab_dict + + bos_token = ( + bos_token + if bos_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.bos_id], + normalized=False, + special=True, + ) + ) + eos_token = ( + eos_token + if eos_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.eos_id], + normalized=False, + special=True, + ) + ) + unk_token = ( + unk_token + if unk_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.unk_id], + normalized=False, + special=True, + ) + ) + pad_token = ( + pad_token + if pad_token + else AddedToken( + self._tokenizer._vocab[self._tokenizer.pad_id], + normalized=False, + special=True, + ) + ) + + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + + self._in_vllm = _called_from_vllm() + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return self._tokenizer.n_words + + def get_vocab(self): + """Returns vocab as a dict""" + return self._vocab_dict + + def tokenize( + self, + text: str, + pair: str | None = None, + add_special_tokens: bool | None = None, + **kwargs, + ) -> list[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + if add_special_tokens is None: + bos = self._add_bos_token + eos = self._add_eos_token + else: + bos = add_special_tokens + eos = add_special_tokens + + input_ids = [] + parts = self.tokens_trie.split(text) + + in_vllm_chat_completion_mode = False + if ( + self._in_vllm + and len(parts) > 1 + and parts[0] == SpecialTokens.bos.value + and parts[1] == SpecialTokens.begin_inst.value + ): + # This is a dangerous hack to make the tokenizer work with vLLM. + # It means we are in chat completion mode. + bos = False + eos = False + in_vllm_chat_completion_mode = True + + if os.environ.get("HF_TOKENIZE_FORCE_NO_SPECIAL_TOKENS", "0") == "1": + bos = False + eos = False + + if not self._in_vllm or in_vllm_chat_completion_mode: + for part in parts: + if part in self.additional_special_tokens and part in self._vocab_dict: + input_ids.append(self._convert_token_to_id(part)) + else: + input_ids.extend(self._tokenizer.encode(part, bos=bos, eos=eos)) + else: + # Doesn't tokenize special tokens properly, but this is the behavior of vLLM when we are in completion mode. 
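+            # (i.e. special-token strings in the raw prompt are encoded as ordinary text here
+            # rather than being mapped to their reserved ids.)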
+ input_ids = self._tokenizer.encode(text, bos=bos, eos=eos) + + if os.environ.get("HF_TOKENIZE_ABUSE", "1") == "1": + # A lot faster than the other option + return input_ids + else: + return [self._convert_id_to_token(token_id) for token_id in input_ids] + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + if len(tokens) > 0 and isinstance(tokens[0], int): + return tokens + return super().convert_tokens_to_ids(tokens) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self._vocab_dict[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + piece = self._tokenizer.id_to_piece(index) + return piece if isinstance(piece, str) else piece.value + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + if self.is_tekken: + tokens = [ + t + for t in tokens + if (t is SpecialTokens.tool_calls or t not in self._tokenizer._all_special_tokens) + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self._tokenizer.num_special_tokens + + def _token_to_id(t: str): + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t + try: + return shift + self._tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + logger.warning( + "Failed to convert token %s to id, replacing with ", + t_bytes, + ) + return self._tokenizer.unk_id + + ids = [_token_to_id(t) for t in tokens] + decoded = self._tokenizer.decode(ids) + else: + decoded = "".join(tokens) + else: + # make sure certain special tokens like Tool calls are + # not decoded + special_tokens = {SpecialTokens.tool_calls} + regular_tokens: list[str] = [] + decoded_list = [] + + for token in tokens: + if token in special_tokens: + if regular_tokens: + decoded_list.append(self._tokenizer.decode(regular_tokens)) + regular_tokens = [] + decoded_list.append(token) + else: + regular_tokens.append(token) + + if regular_tokens: + decoded_list.append(self._tokenizer.decode(regular_tokens)) # type: ignore[no-untyped-call] + + decoded = "".join(decoded_list) + + return decoded + + def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]: + """ + Use this method to save the full tokenizer file. + """ + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join(save_directory, "tekken.json") + + if os.path.abspath(self._mistral_tokenizer_path) != os.path.abspath(out_vocab_file): + copyfile(self._mistral_tokenizer_path, out_vocab_file) + + return (out_vocab_file,) + + def apply_chat_template( + self, + conversation: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tokenize: bool = True, + **kwargs, + ) -> list[int]: + request = _make_mistral_chat_completion_request(conversation, tools) + encoded = self._mistral_tokenizer.encode_chat_completion(request) + if tokenize: + # encode-decode to get clean prompt + return encoded.tokens + else: + return encoded.text + + +def _find_tokenizer_file(files: list[str]): + file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") + + matched_files = [file for file in files if file_pattern.match(file)] + if len(matched_files) > 1: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. 
Make sure only one Mistral " + f"tokenizer is present in {files}." + ) + elif len(matched_files) == 0: + raise OSError( + f"Found {len(matched_files)} files matching the " + f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " + f"tokenizer is present in {files}." + ) + + return matched_files[0] + + +def _make_mistral_chat_completion_request( + messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None +) -> "ChatCompletionRequest": + last_message = messages[-1] + if last_message["role"] == "assistant": + last_message["prefix"] = True + + # mistral-common requires AssistantMessage content to be string [1]. + # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict to be present, even if it's empty. + if tools: + for function in [tool["function"] for tool in tools if tool["type"] == "function"]: + if function.get("parameters") is None: + function["parameters"] = {} + + from mistral_common.protocol.instruct.request import ChatCompletionRequest + + return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] From 7da0a8a164b1c844b941ced925fb40d6fab96a98 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 10:01:44 +0100 Subject: [PATCH 10/49] Add decilm code Signed-off-by: Daniel Korzekwa --- .../deci_lm_hf_code/megatron_lm__tokenizer.py | 187 +++++++++++++++++ .../deci_lm_hf_code/tokenization_decilm.py | 195 ++++++++++++++++++ 2 files changed, 382 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py new file mode 100644 index 000000000..5c641d25b --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__tokenizer.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +"""Megatron tokenizers.""" + +import base64 +import json +from pathlib import Path + +from .megatron_lm__megatron_tokenizer import MegatronTokenizer + + +def reload_mergeable_ranks( + path: str, + max_vocab: int | None = None, +) -> dict[bytes, int]: + """ + Reload our tokenizer JSON file and convert it to Tiktoken format. 
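+    Each vocab entry is expected to carry the keys "rank", "token_bytes" (base64-encoded
+    merge bytes) and "token_str"; the returned dict maps the decoded merge bytes to their
+    rank, i.e. the mergeable_ranks format expected by tiktoken.Encoding.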
+ """ + assert path.endswith(".json") + + # reload vocab + with open(path) as f: + vocab = json.load(f) + assert isinstance(vocab, list) + print(f"Vocab size: {len(vocab)}") + if max_vocab is not None: + vocab = vocab[:max_vocab] + print(f"Cutting vocab to first {len(vocab)} tokens.") + + # build ranks + ranks: dict[bytes, int] = {} + for i, x in enumerate(vocab): + assert x.keys() == {"rank", "token_bytes", "token_str"} + assert x["rank"] == i + merge = base64.b64decode(x["token_bytes"]) + assert i >= 256 or merge == bytes([i]) + ranks[merge] = x["rank"] + + # sanity check + assert len(ranks) == len(vocab) + assert set(ranks.values()) == set(range(len(ranks))) + + return ranks + + +PATTERN_TIKTOKEN = ( + r"[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" +) +PATTERN_TIKTOKEN_V2 = ( + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+" + "|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*" + "|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +) + + +class CustomTikTokenizer(MegatronTokenizer): + def __init__( + self, + path: str, + pattern: str, + vocab_size: int, + num_special_tokens: int, + special_tokens: list[str] | None, + ): + super().__init__( + path, + pattern=pattern, + vocab_size=vocab_size, + num_special_tokens=num_special_tokens, + special_tokens=special_tokens, + ) + import tiktoken + + # if vocab_size is None: + # vocab_size = 2**17 # Fallback vocab size is 131072. + self._vocab_size = vocab_size + + special_tokens_default = ["", "", ""] + if special_tokens is None: + special_tokens = special_tokens_default.copy() + assert len(special_tokens) == len(set(special_tokens)), ( + f"Special tokens should be unique: {special_tokens}" + ) + assert len(special_tokens) <= num_special_tokens < self._vocab_size + assert set(special_tokens_default) <= set(special_tokens), ( + f"Custom special tokens should include {special_tokens_default}" + ) + + special_filler = [ + "".format(id=i) for i in range(len(special_tokens), num_special_tokens) + ] + if special_filler: + print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}") + special_tokens = special_tokens + special_filler + assert len(set(special_tokens)) == len(special_tokens) == num_special_tokens, special_tokens + inner_vocab_size = self._vocab_size - num_special_tokens + + token_to_id_without_special_tokens = reload_mergeable_ranks( + path, max_vocab=inner_vocab_size + ) + # Create space for special tokens. + token_to_id_without_special_tokens = { + t: i + num_special_tokens for t, i in token_to_id_without_special_tokens.items() + } + + special_tokens = {t: i for i, t in enumerate(special_tokens)} + self._unk_id = special_tokens[""] + self._bos_id = special_tokens[""] + self._eos_id = special_tokens[""] + + # Create tiktoken model. + self._model = tiktoken.Encoding( + name=Path(path).parent.name, + pat_str=pattern, + mergeable_ranks=token_to_id_without_special_tokens, + special_tokens=special_tokens, + ) + + # Create final _id_to_token and _token_to_id data structures with special tokens inserted + # into appropriate locations. 
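+        # Resulting id layout: ids in [0, num_special_tokens) hold the special and filler tokens,
+        # ids in [num_special_tokens, vocab_size) hold the regular merges shifted by that offset.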
+ assert set(token_to_id_without_special_tokens.keys()).isdisjoint(set(special_tokens.keys())) + self._token_to_id = token_to_id_without_special_tokens.copy() + self._token_to_id.update(special_tokens) + self._id_to_token = {v: k for k, v in self._token_to_id.items()} + assert set(range(self._vocab_size)) == set(self._id_to_token.keys()) + + @property + def bos(self) -> int: + return self._bos_id + + @property + def eos(self) -> int: + return self._eos_id + + @property + def unk(self) -> int: + return self._unk_id + + @property + def eod(self) -> int: + return self._eos_id + + @property + def vocab(self): + return self._token_to_id + + @property + def inv_vocab(self): + return self._id_to_token + + def tokenize(self, s: str, bos: bool = False, eos: bool = False) -> list[int]: + tokens = self._model.encode_ordinary(s) + if bos: + tokens = [self.bos, *tokens] + if eos: + tokens = [*tokens, self.eos] + + return tokens + + def detokenize(self, tokens: list[int]) -> str: + return self._model.decode(tokens) + + @property + def vocab_size(self) -> int: + return self._vocab_size + + @property + def encoder(self): + return self._token_to_id + + @property + def decoder(self): + return self._id_to_token diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py new file mode 100644 index 000000000..14c840b8b --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_decilm.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Only needed for DeciLM models that use Megatron tokenizers. +DeciLM models that use Llama tokenizers do not need external code. 
+""" + +import json +import os +from pathlib import Path +from typing import Literal + +from transformers import PreTrainedTokenizer +from transformers.dynamic_module_utils import custom_object_save +from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE, AddedToken + +from .megatron_lm__megatron_tokenizer import ( + MegatronTokenizer, # fake import to make AutoTokenizer infer the dependency +) +from .megatron_lm__tokenizer import PATTERN_TIKTOKEN, PATTERN_TIKTOKEN_V2, CustomTikTokenizer + +MegatronTokenizer # make sure that auto-formatting doesn't remove the import + + +class MegatronTikTokenizer(PreTrainedTokenizer): + vocab_files_names: dict[str, str] = {"vocab_file": "tiktoken_vocab.json"} + model_input_names: list[str] = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file: str, + tiktoken_pattern: Literal["v1", "v2"], + vocab_size: int, + tiktoken_num_special_tokens: int, + tiktoken_special_tokens: list[str] | None, + add_bos_token: bool = False, # nm5 does not use bos token + add_eos_token: bool = False, # nm5 does not use eos token + **unused_kwargs, + ): + assert "chat_template" not in unused_kwargs, ( + "We enforce the Nemotron5 chat template from the code, " + "please do not provide a chat_template in the tokenizer_config.json file" + ) + + pattern = PATTERN_TIKTOKEN if tiktoken_pattern == "v1" else PATTERN_TIKTOKEN_V2 + self._tokenizer = CustomTikTokenizer( + path=vocab_file, + pattern=pattern, + vocab_size=vocab_size, + num_special_tokens=tiktoken_num_special_tokens, + special_tokens=tiktoken_special_tokens, + ) + + eos_token = self._tokenizer.detokenize([self._tokenizer.eos]) + bos_token = self._tokenizer.detokenize([self._tokenizer.bos]) + self.vocab = self._tokenizer.vocab + super().__init__( + eos_token=AddedToken(eos_token, normalized=False, special=True), + bos_token=AddedToken(bos_token, normalized=False, special=True), + pad_token=AddedToken(eos_token, normalized=False, special=True), + ) + + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.chat_template = NEMOTRON5_CHAT_TEMPLATE + + self._vocab_file_contents = Path(vocab_file).read_text() + self._tokenizer_config = { + "tiktoken_pattern": tiktoken_pattern, + "vocab_size": vocab_size, + "tiktoken_num_special_tokens": tiktoken_num_special_tokens, + "tiktoken_special_tokens": tiktoken_special_tokens, + "add_bos_token": add_bos_token, + "add_eos_token": add_eos_token, + "tokenizer_class": "MegatronTikTokenizer", + "auto_map": { + "AutoTokenizer": ["tokenization_decilm.MegatronTikTokenizer", None], + }, + } + + def get_vocab(self) -> dict[str, int]: + """to satisfy PreTrainedTokenizer.__init__()""" + return self.vocab + + def tokenize(self, text: str, **kwargs) -> list[str]: + return [text] + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + is_single_token = isinstance(tokens, str) + if is_single_token: + text = tokens + else: + assert len(tokens) == 1 + text = tokens[0] + + ids = self._tokenizer._model.encode(text, allowed_special="all") + + if is_single_token: + assert len(ids) == 1, ( + f"Asked to convert a single token to its id, but it's not a single token: encode('{tokens}') = {ids}" + ) + return ids[0] + else: + return ids + + def convert_ids_to_tokens( + self, ids: int | list[int], skip_special_tokens: bool = False + ) -> str | list[str]: + is_single_id = isinstance(ids, int) + if is_single_id: + ids = [ids] + + if skip_special_tokens: + ids = [idd for idd in ids if idd not in (self.eos_token_id, self.bos_token_id)] + + text 
= self._tokenizer.detokenize(ids) + + if is_single_id: + return text + else: + return [text] + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """Taken from LlamaTokenizer""" + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def save_pretrained( + self, + save_directory: str | os.PathLike, + legacy_format: bool | None = None, + filename_prefix: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> tuple[str, ...]: + assert legacy_format is None, "Unsupported" + assert filename_prefix is None, "Unsupported" + assert not push_to_hub, "Unsupported" + + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + tokenizer_config_path = save_directory / TOKENIZER_CONFIG_FILE + tokenizer_config_path.write_text(json.dumps(self._tokenizer_config, indent=2)) + + vocab_files_name = self.vocab_files_names["vocab_file"] + vocab_file_path = save_directory / vocab_files_name + vocab_file_path.write_text(self._vocab_file_contents) + + custom_object_save(self, save_directory) + + return str(tokenizer_config_path), str(vocab_file_path) + + +NEMOTRON5_CHAT_TEMPLATE = """{% if messages[0].role != "system" %} + {% set messages = [{"role": "system", "content": ""}] + messages %} +{% endif %} +{% for message in messages %} + {% if message.role == "system" %} +System +{{ message.content }} + {% elif message.role == "user" %} +User +{{ message.content }} + {% elif message.role == "assistant" %} +Assistant +{{ message.content }} + {% endif %} +{% endfor %} +{% if add_generation_prompt %} +Assistant +{% else %} + +{% endif %}""" From 6e09a81762131f1f662d3dfe1e507a6e98f19a94 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 10:07:25 +0100 Subject: [PATCH 11/49] Add decilm code Signed-off-by: Daniel Korzekwa --- .../megatron_lm__mamba_mixer.py | 527 ++++++++++++++++++ 1 file changed, 527 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py new file mode 100644 index 000000000..76dbb3473 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
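# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the vendored sources) for the
# MegatronTikTokenizer added in the previous patch. The vocab path and the
# size values below are placeholders, and the class is passed in as an
# argument to avoid hard-coding an import path. apply_chat_template() is
# inherited from the Hugging Face tokenizer base class and renders the
# NEMOTRON5_CHAT_TEMPLATE assigned in __init__.

def _tiktokenizer_example(MegatronTikTokenizer, vocab_file: str):
    tok = MegatronTikTokenizer(
        vocab_file=vocab_file,              # e.g. a local "tiktoken_vocab.json" (placeholder)
        tiktoken_pattern="v2",              # selects PATTERN_TIKTOKEN_V2
        vocab_size=131072,                  # illustrative value
        tiktoken_num_special_tokens=1000,   # illustrative value
        tiktoken_special_tokens=None,       # let CustomTikTokenizer fall back to its defaults
    )
    messages = [{"role": "user", "content": "Hello!"}]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ids = tok.convert_tokens_to_ids(tok.tokenize(prompt))  # tokenize() returns [prompt] as-is
    return prompt, ids
# ---------------------------------------------------------------------------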
+ +# Adapted from megatron.core.ssm.mamba_mixer.MambaMixer: +# https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/0b5140009fb9011eceaef6d36ea1181a8d176479/megatron/core/ssm/mamba_mixer.py + +# ruff: noqa: N803, N806 + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from einops import rearrange, repeat + from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import ( + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + ) + + class MambaMixerMegatron(nn.Module): + """ + Args: + d_model: The hidden size of the model. + d_state: The state size of the SSM. + d_conv: The number of channels in the causal convolution. + conv_init: The initialization range for the causal convolution weights. + nheads: The number of Mamba heads. Used to calculate the expansion factor for the SSM + instead of the deprecated arg "expand". + headdim: The hidden size of each attention head. + ngroups: The number of attention heads. + A_init_range: The initialization range for the attention weights. + D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden + state. + rmsnorm: Whether to use root mean square normalization. + norm_before_gate: Whether to apply normalization before the gating mechanism. + dt_min: The minimum value of the dt parameter. + dt_max: The maximum value of the dt parameter. + dt_init: The initialization value of the dt parameter. + dt_scale: The scaling factor for the dt parameter. + dt_init_floor: The minimum value of the dt parameter after initialization. + bias: Whether to use bias in the linear layers. + conv_bias: Whether to use bias in the causal convolution. + chunk_size: The chunk size for the fused kernel. + use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model. + layer_number: The layer number of this Mamba layer. 
+ """ + + def __init__( + self, + d_model, + d_state=256, + d_conv=4, + conv_init=None, + nheads=256, + headdim=64, + ngroups=8, + A_init_range=(1, 16), + D_has_hdim=False, + rmsnorm=True, + norm_before_gate=False, + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + bias=False, + conv_bias=True, + # Fused kernel and sharding options + chunk_size=128, + use_mem_eff_path=True, + layer_number=None, + ): + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.conv_init = conv_init + self.nheads = nheads + self.headdim = headdim + self.ngroups = ngroups + self.D_has_hdim = D_has_hdim + self.rmsnorm = rmsnorm + self.norm_before_gate = norm_before_gate + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.layer_number = layer_number + + self.d_inner = self.nheads * self.headdim + + self.tensor_model_parallel_size = 1 + assert self.d_inner % self.tensor_model_parallel_size == 0 + assert self.ngroups % self.tensor_model_parallel_size == 0 + assert self.nheads % self.tensor_model_parallel_size == 0 + assert not bias + assert not self.norm_before_gate + + self.d_inner_local = self.d_inner // self.tensor_model_parallel_size + self.ngroups_local = self.ngroups // self.tensor_model_parallel_size + self.nheads_local = self.nheads // self.tensor_model_parallel_size + + assert self.d_inner_local % self.ngroups_local == 0 + + # Assume sequence parallelism: input is already partitioned along the + # sequence dimension + self.in_proj = nn.Linear( + self.d_model, + self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, # AB CD E + bias=False, + ) + + conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state # A CD + + # weight dim: [conv_dim, conv_dim, d_conv] + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + ) + + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.activation = "silu" + self.act = nn.SiLU() + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.nheads_local) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Our initialization would set all Linear.bias to zero, + # need to mark this one as _no_reinit + self.dt_bias._no_reinit = True + # Just to be explicit. 
Without this we already don't + # put wd on dt_bias because of the check + + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads_local, dtype=torch.float32).uniform_(*A_init_range) + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter( + torch.ones( + self.d_inner_local if self.D_has_hdim else self.nheads_local, + ) + ) # Keep in fp32 + self.D._no_weight_decay = True + + if self.rmsnorm: + self.norm = RMSNormGated( + self.d_inner_local, + eps=1e-5, + group_size=self.d_inner_local // self.ngroups_local, + norm_before_gate=self.norm_before_gate, + ) + + # Assume sequence parallelism: input is partitioned along d_inner and + # output is partitioned along the sequence dimension + self.out_proj = nn.Linear( + self.d_inner, + self.d_model, + bias=False, + ) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (nL, B, D) / (L B D) + Returns: same shape as hidden_states + """ + _, batch, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + # assert not self.config.sequence_parallel + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, out_bias, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out, out_bias + + # (nheads_local) + A = -torch.exp(self.A_log.float()) + + # xz, _ = self.in_proj(hidden_states) # TransformerEngine also returns bias + xz = self.in_proj(hidden_states) + + # transpose: l b pd --> b l pd + xz = rearrange(xz, "l b d -> b l d").contiguous() + + if self.use_mem_eff_path and inference_params is None: + assert ssm_state is None + + if self.conv1d.bias is not None: + self.conv1d.bias.data_ptr() + + y = mamba_split_conv1d_scan_combined( + xz, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.dt_bias.float(), + A, + D=( + rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.D + ), + chunk_size=self.chunk_size, + activation=self.activation, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.ngroups_local, + norm_before_gate=self.norm_before_gate, + ) + + if self.rmsnorm: + y = self.norm(y) + else: + z, xBC, dt = torch.split( + xz, + [ + self.d_inner_local, + self.d_inner_local + 2 * self.ngroups_local * self.d_state, + self.nheads_local, + ], + dim=-1, + ) + + # transpose: b l pd --> b pd l + xBC = rearrange(xBC, "b l d -> b d l").contiguous() + + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
+ conv_state.copy_( + F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) + ) # Update state (B D W) + + seqlen = xBC.size(2) + if causal_conv1d_fn is None: + xBC = self.act(self.conv1d(xBC)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + xBC = causal_conv1d_fn( + x=xBC, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # transpose b pd l --> b l pd + xBC = rearrange(xBC, "b d l -> b l d").contiguous() + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + + # TO DO Vijay: fuse most of the transposes with the GEMMS + x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() + dt = dt.contiguous() + B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() + C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() + z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=( + rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.D + ), + z=z if not self.rmsnorm else None, + dt_bias=self.dt_bias.float(), + dt_softplus=True, + return_final_states=ssm_state is not None, + ) + + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + + if self.rmsnorm: + y = rearrange(y, "b l h p -> b l (h p)").contiguous() + z = rearrange(z, "b l h p -> b l (h p)").contiguous() + y = self.norm(y, z) + else: + y = rearrange(y, "b l h p -> b l (h p)").contiguous() + + y = rearrange(y, "b l d -> l b d").contiguous() + # out, out_bias = self.out_proj(y) # TransformerEngine also returns bias + out = self.out_proj(y) + + return out + + def step(self, hidden_states, conv_state, ssm_state): + """ + Performs inference step for decoding + """ + # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now" + dtype = hidden_states.dtype + assert hidden_states.shape[0] == 1, ( + "Only support decoding with 1 token at a time for now" + ) + + # l b d --> b d + hidden_states = hidden_states.squeeze(0) + + # b d_model --> b p(2d) + xz, _ = self.in_proj(hidden_states) + + z, xBC, dt = torch.split( + xz, + [ + self.d_inner_local, + self.d_inner_local + 2 * self.ngroups_local * self.d_state, + self.nheads_local, + ], + dim=-1, + ) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum( + conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 + ) # (B D) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + else: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x, B, C = torch.split( + xBC, + [ + self.d_inner_local, + self.ngroups_local * self.d_state, + self.ngroups_local * self.d_state, + ], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) + + # SSM step + if selective_state_update is None: + if self.ngroups_local > 1: + B = rearrange(B, "b (g n) -> b g n", n=self.d_state) + C = rearrange(C, "b (g n) -> b g n", n=self.d_state) + B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) + + dt = repeat(dt, "b h -> b (h p)", p=self.headdim) + dt_bias = 
repeat(self.dt_bias, "h -> (h p)", p=self.headdim) + A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state) + D = repeat(self.D, "h -> (h p)", p=self.headdim) + + dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + + dB_x = torch.einsum("bd,bdn,bd->bdn", dt, B, x) + ssm_state.copy_( + ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim) + + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim) + ) + + y = torch.einsum( + "bdn,bdn->bd", + rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim), + C, + ) + y = y + D.to(dtype) * x + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + # Discretize A and B (b (g n)) + dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) + dA = torch.exp(dt * A) + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update( + ssm_state, + x_reshaped, + dt, + A, + B, + C, + D, + z=z if not self.rmsnorm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + y = rearrange(y, "b h p -> b (h p)") + + if self.rmsnorm: + y = self.norm(y, z) + + # b pd --> b d + out, out_bias = self.out_proj(y) + return out.unsqueeze(0), out_bias, conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + """ + allocate inference cache + """ + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=device, + dtype=conv_dtype, + ) + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_number is not None + if self.layer_number not in inference_params.key_value_memory_dict: + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads_local, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_number] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_number] + # TO DO: What if batch size changes between generation, and we reuse the same states? 
+ if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + +except ImportError as exception: + mamba_error_message = f"Cannot declare MambaMixer due to missing dependencies: {exception=}." + warnings.warn(mamba_error_message) + + # TODO: Investigate why this type ignore is needed + class MambaMixerMegatron(nn.Module): # type: ignore[no-redef] + def __init__(self, *args, **kwargs): + raise ImportError(mamba_error_message) From 2e3f5da1d9f5428082f15d6af32b7c0b13ac93b1 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 10:07:57 +0100 Subject: [PATCH 12/49] Add decilm code Signed-off-by: Daniel Korzekwa --- .../megatron_lm__megatron_tokenizer.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py new file mode 100644 index 000000000..1b3840a30 --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/megatron_lm__megatron_tokenizer.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
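# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the vendored sources) for the
# MambaMixerMegatron defined in the previous patch. It assumes mamba_ssm,
# causal_conv1d and a CUDA device are available; the sizes below are arbitrary
# and the fused Triton kernels may impose additional shape/dtype constraints.

import torch

from modelopt.torch._compress.decilm.deci_lm_hf_code.megatron_lm__mamba_mixer import (
    MambaMixerMegatron,
)


def _mamba_mixer_smoke_test():
    mixer = (
        MambaMixerMegatron(
            d_model=512,   # hidden size of the surrounding model
            d_state=128,   # SSM state size
            nheads=8,      # number of Mamba heads
            headdim=64,    # per-head width, so d_inner = nheads * headdim = 512
            ngroups=8,
            layer_number=0,
        )
        .cuda()
        .to(torch.bfloat16)
    )
    # Like Megatron-LM, the mixer expects (seq_len, batch, d_model) inputs and
    # returns a tensor of the same shape.
    x = torch.randn(16, 2, 512, device="cuda", dtype=torch.bfloat16)
    y = mixer(x)
    assert y.shape == x.shape
# ---------------------------------------------------------------------------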
+ +import json +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any + +import numpy + + +class MegatronTokenizer(ABC): + """Abstract class for tokenizer + + Absent a config or class-specific tracking of which objects are uniquely identifying, we must + include all key word arguments as unique identifiers + + Args: + tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes + + tokenizer_options (Dict[str, Any]): All tokenizer options + """ + + def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any): + self.unique_identifiers = OrderedDict() + self.unique_identifiers["class"] = type(self).__name__ + self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths) + for option in tokenizer_options: + self.unique_identifiers[option] = str(tokenizer_options[option]) + + self.unique_description = json.dumps(self.unique_identifiers, indent=4) + + super().__init__() + + @abstractmethod + def tokenize(self, text: str) -> numpy.ndarray: + """Convert text to embedding ids + + Args: + text (str): The text to convert + + Returns: + numpy.ndarray: The converted embedding ids + """ + + def detokenize(self, ids: numpy.ndarray) -> str: + """Convert embedding ids to text + + Args: + ids (numpy.ndarray): The ids to convert + + Returns: + str: The converted text + + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token""" + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token""" + + @property + @abstractmethod + def vocab_size(self): + """The vocabulary size""" + + @property + def cls(self): + """The CLS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__)) + + @property + def sep(self): + """The SEP token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__)) + + @property + def pad(self): + """The PAD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + + @property + def eod(self): + """The EOD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__)) + + @property + def bos(self): + """The BOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + + @property + def eos(self): + """The EOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + + @property + def mask(self): + """The MASK token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__)) From 418890e26101110cc5b875632a6251c85133202a Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 11:25:48 +0100 Subject: [PATCH 13/49] Add decilm code Signed-off-by: Daniel Korzekwa --- .../decilm/deci_lm_hf_code/modeling_decilm.py | 2627 +++++++++++++++++ 1 
file changed, 2627 insertions(+) create mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py new file mode 100644 index 000000000..808533d7f --- /dev/null +++ b/modelopt/torch/_compress/decilm/deci_lm_hf_code/modeling_decilm.py @@ -0,0 +1,2627 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved. +# +# This code for Nvidia's model is based on the Llama modeling code by HuggingFace, +# which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and +# OPT implementations in this library. +# Sliding window code based on Gemma2 by Google. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
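# ---------------------------------------------------------------------------
# Illustrative toy subclass (not part of the vendored sources) showing one way
# the MegatronTokenizer ABC from the previous patch can be satisfied: a
# whitespace "tokenizer" with a three-word vocabulary implementing tokenize()
# plus the vocab/inv_vocab/vocab_size properties. Real implementations such as
# CustomTikTokenizer wrap an actual vocab file.

import numpy

from modelopt.torch._compress.decilm.deci_lm_hf_code.megatron_lm__megatron_tokenizer import (
    MegatronTokenizer,
)


class ToyWhitespaceTokenizer(MegatronTokenizer):
    def __init__(self, vocab_path: str):
        # The positional path and keyword options only feed unique_identifiers.
        super().__init__(vocab_path, lowercase=False)
        words = ["<unk>", "hello", "world"]  # stand-in vocabulary
        self._vocab = {w: i for i, w in enumerate(words)}
        self._inv_vocab = {i: w for w, i in self._vocab.items()}

    def tokenize(self, text: str) -> numpy.ndarray:
        return numpy.array([self._vocab.get(w, 0) for w in text.split()], dtype=numpy.int64)

    def detokenize(self, ids: numpy.ndarray) -> str:
        return " ".join(self._inv_vocab[int(i)] for i in ids)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def vocab_size(self):
        return len(self._vocab)


# Example: ToyWhitespaceTokenizer("unused.json").tokenize("hello world") -> array([1, 2])
# ---------------------------------------------------------------------------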
+# mypy: ignore-errors + +import math + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import GenerationConfig +from transformers.generation.utils import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .block_config import AttentionConfig, FFNConfig, MambaConfig, MoEConfig +from .configuration_decilm import DeciLMConfig +from .megatron_lm__mamba_mixer import MambaMixerMegatron +from .transformers_4_44_2__activations import ACT2FN +from .transformers_4_44_2__cache_utils import Cache, StaticCache +from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter +from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import ( + _flash_attention_forward, +) +from .transformers_4_44_2__modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS +from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS +from .transformers_4_51_3__modeling_llama4_attention import Llama4TextAttention, Llama4TextConfig +from .variable_cache import VariableCache +from .vllm_yarn_utils import YaRNScalingRotaryEmbedding + +# from transformers.models.llama4.modeling_llama4 import Llama4TextL2Norm +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM" +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeciLMConfig" + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or + a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be + as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class DeciLMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeciLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm) + + +class DeciLMRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: DeciLMConfig | None = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_impl = "rope" if config is None else config.position_embedding_type + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + def _set_inv_freq_if_needed(self, device: torch.device) -> None: + is_missing_inv_freq = not hasattr(self, "inv_freq") + is_meta_mismatch = not is_missing_inv_freq and ( + str(device) != "meta" and self.inv_freq.is_meta + ) + + if is_missing_inv_freq or is_meta_mismatch: + with torch.device(device): + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, **self.rope_kwargs + ) + self.original_inv_freq = inv_freq + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + self._set_inv_freq_if_needed(x.device) + + if self.rope_impl == "rope_llama4": + return self.llama4_forward(x, position_ids) + else: + return self.llama3_forward(x, position_ids) + + def llama3_forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def llama4_forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +class DeciMistralYarnRotaryEmbedding(nn.Module): + def __init__(self, config: DeciLMConfig): + super().__init__() + self.config = config + self.rope_scaling = config.rope_scaling + self.base = config.rope_theta + self.rope_impl = config.position_embedding_type + self.head_size = config.hidden_size // config.num_attention_heads + self.yarn = YaRNScalingRotaryEmbedding( + head_size=self.head_size, + rotary_dim=self.head_size, + max_position_embeddings=self.rope_scaling["original_max_position_embeddings"], + base=self.base, + is_neox_style=True, + scaling_factor=self.rope_scaling["factor"], + beta_fast=self.rope_scaling["beta_fast"], + beta_slow=self.rope_scaling["beta_slow"], + dtype=torch.float32, + ) + self.attention_scaling = self.yarn.mscale + self.scaling_factor = self.rope_scaling["factor"] + self.rope_impl = "rope" if config is None else config.position_embedding_type + self.rope_impl = "even_odd" + + def _set_inv_freq_if_needed(self, device: torch.device) -> None: + is_missing_inv_freq = not hasattr(self, "inv_freq") + is_meta_mismatch = not is_missing_inv_freq and ( + str(device) != "meta" and self.inv_freq.is_meta + ) + + if is_missing_inv_freq or is_meta_mismatch: + with torch.device(device): + inv_freq = self.yarn._compute_inv_freq(self.scaling_factor) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def halves_forward(self, x, position_ids): + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + + self._set_inv_freq_if_needed(x.device) + + # print(f"halves_forward") + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + inv_freq_expanded = inv_freq_expanded.to(x.device) + # print(f"inv_freq_expanded: {inv_freq_expanded.device}") + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def forward(self, x, position_ids): + if self.rope_impl == "halves": + return self.halves_forward(x, position_ids) + elif self.rope_impl == "even_odd": + return self.even_odd_forward(x, position_ids) + else: + raise ValueError(f"Invalid rope implementation: {self.rope_impl}") + + def even_odd_forward(self, x, position_ids): + device_type = x.device.type + device_type = ( + device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + + self._set_inv_freq_if_needed(x.device) + + # print(f"even_odd_forward") + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # Convert to complex representation + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + freqs_cis = freqs_cis * self.attention_scaling + return freqs_cis + + +class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding): + """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding): + """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " + "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +rope_type_to_class = { + "default": DeciLMRotaryEmbedding, + "linear": DeciLMLinearScalingRotaryEmbedding, + "dynamic": DeciLMDynamicNTKScalingRotaryEmbedding, + "rope_llama4": DeciLMRotaryEmbedding, + "rope": DeciLMRotaryEmbedding, + "mistral_yarn": DeciMistralYarnRotaryEmbedding, +} + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, freqs_cis, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + freqs_cis (`torch.Tensor`): The frequency tensor. + a tuple of two tensors, cos and sin. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # print(f"applying first half-second half") + cos, sin = freqs_cis + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def vllm_apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + # print(f"freqs_cis: {freqs_cis.shape}, xq_: {xq_.shape}, xk_: {xk_.shape}") + xq_out = torch.view_as_real(xq_ * freqs_cis[:, None, :, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, None, :, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class DeciLMGatedMLP(nn.Module): + def __init__( + self, + config: DeciLMConfig, + ffn_config: FFNConfig, + ): + super().__init__() + self.config = config + self.ffn_config = ffn_config + self.hidden_size = config.hidden_size + self.intermediate_size = ffn_config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[ffn_config.hidden_act] + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + 
down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class DeciLMVanillaMLP(nn.Module): + def __init__( + self, + config: DeciLMConfig, + ffn_config: FFNConfig, + ): + super().__init__() + self.config = config + self.ffn_config = ffn_config + self.hidden_size = config.hidden_size + self.intermediate_size = ffn_config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[ffn_config.hidden_act] + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + assert self.config.pretraining_tp == 1, ( + "Unsupported pretraining_tp != 1 for DeciLMVanillaMLP" + ) + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DeciLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: DeciLMConfig, + attention_config: AttentionConfig, + layer_idx: int | None = None, + ): + super().__init__() + self.config = config + self.attention_config = attention_config # type: AttentionConfig + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code + self.num_key_value_heads = ( + self.num_heads // self.num_key_value_groups + ) # DeciLM-specific code + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + # llama4 attention specific + self.llama4_attn_config = attention_config.llama4 + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.o_proj_bias + ) + + if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: + # TO DO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = rope_type_to_class[self.config.position_embedding_type]( + config=self.config + ) + + if attention_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + self.is_llama4 = self.llama4_attn_config is not None + if ( + self.is_llama4 + and self.llama4_attn_config.use_qk_norm + and self.llama4_attn_config.use_rope + ): + self.qk_norm = Llama4TextL2Norm(self.config.rms_norm_eps) + + self.use_rope = ( + self.llama4_attn_config.use_rope + if self.is_llama4 + else self.config.position_embedding_type in ["rope", "mistral_yarn"] + ) + self.rope_impl = self.rotary_emb.rope_impl + self.apply_rope_fn = ( + apply_rotary_emb + if self.rope_impl in ["even_odd", "rope_llama4"] + else apply_rotary_pos_emb + ) + # self.apply_rope_fn = apply_rotary_emb + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # will become mandatory in v4.45 + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + + if self.config.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) 
+ + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if self.use_rope: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE " + "embeddings internally through `position_ids` (2D tensor with the indexes of the " + "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " + "containing cos and sin). In v4.45 `position_ids` will be removed and " + "`position_embeddings` will be mandatory." + ) + freqs_cis = self.rotary_emb(value_states, position_ids) + else: + freqs_cis = position_embeddings + + query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + if self.is_llama4: + query_states = self.apply_attention_scaling(input_shape, cache_position, query_states) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + # print(f"cache_position: {cache_position}") + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def apply_attention_scaling(self, input_shape, cache_position, query_states): + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.llama4_attn_config.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log( + torch.floor( + 
(cache_position.float() + 1.0) / self.llama4_attn_config.floor_scale + ) + + 1.0 + ) + * self.llama4_attn_config.attn_scale + + 1.0 + ) + attn_scales = attn_scales.view((*input_shape, 1, 1)).transpose(1, 2) + query_states = (query_states * attn_scales).to(query_states.dtype) + return query_states + return query_states + + +class DeciLMFlashAttention2(DeciLMAttention): + """ + DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is + # bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is + # used to handle this difference. + # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case + # q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + self.sliding_window = self.attention_config.prefill_sliding_window + + self.pre_attention_identity_query = nn.Identity() # for debugging hooks + self.pre_attention_identity_key = nn.Identity() # for debugging hooks + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # will become mandatory in v4.45 + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if self.config.position_embedding_type in ["rope", "mistral_yarn"]: + # llama4 doesn't use flash attention + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE " + "embeddings internally through `position_ids` (2D tensor with the indexes of the " + "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " + "containing cos and sin). In v4.45 `position_ids` will be removed and " + "`position_embeddings` will be mandatory." 
+ ) + freqs_cis = self.rotary_emb(value_states, position_ids) + else: + freqs_cis = position_embeddings + + query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freq_cis) + # print(f"applying even odd rope") + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV + # cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeciLMRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + query_states = self.pre_attention_identity_query(query_states) + key_states = self.pre_attention_identity_key(key_states) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=self.sliding_window, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +DECILM_ATTENTION_CLASSES = { + "eager": DeciLMAttention, + "flash_attention_2": DeciLMFlashAttention2, +} + + +class DeciLMLlama4TextAttention(Llama4TextAttention): + def __init__(self, config: DeciLMConfig, layer_idx: int, attention_config: AttentionConfig): + llama4_text_config = Llama4TextConfig( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_attention_heads // attention_config.n_heads_in_group, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + attn_scale=attention_config.llama4.attn_scale, + floor_scale=attention_config.llama4.floor_scale, + attn_temperature_tuning=attention_config.llama4.attn_temperature_tuning, + attention_dropout=attention_config.llama4.attention_dropout, + use_qk_norm=attention_config.llama4.use_qk_norm, + use_rope=attention_config.llama4.use_rope, + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + attn_implementation=config.llama4_attn_implementation, + rope_scaling=config.rope_scaling, + max_position_embeddings=config.max_position_embeddings, + attention_chunk_size=attention_config.llama4.attention_chunk_size, + ) + super().__init__(llama4_text_config, layer_idx, use_rope=attention_config.llama4.use_rope) + + +class DeciLMDecoderLayer(nn.Module): + # DeciLM-specific code + def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.block_config = config.get_block_config(layer_idx) + + self.attention_config = self.block_config.attention + self.ffn_config = self.block_config.ffn + self.layer_idx = layer_idx + + if not self.attention_config.no_op: + self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.attention_config.replace_with_linear: + self.self_attn = DeciLMLinearAttention(config) + elif self.attention_config.is_mamba: + self.self_attn = DeciLMMambaMixer(config, self.attention_config.mamba) + elif not self.attention_config.is_llama4: + self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation]( + config=config, attention_config=self.attention_config, layer_idx=layer_idx + ) + else: + self.self_attn = DeciLMLlama4TextAttention(config, layer_idx, self.attention_config) + + if not (self.ffn_config.no_op or self.attention_config.is_mamba): + if self.ffn_config.hidden_act is None: + print(f"WARNING: FFN hidden_act is None for layer {layer_idx}") + + self.post_attention_layernorm = DeciLMRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + if self.ffn_config.replace_with_linear: + self.mlp = DeciLMLinearMLP(config) + elif self.ffn_config.is_moe: + self.mlp = DeciLMMoe(config, self.ffn_config) + 
else: + self.mlp = ( + DeciLMGatedMLP(config, self.ffn_config) + if self.ffn_config.gated + else DeciLMVanillaMLP(config, self.ffn_config) + ) + + self.is_sliding = self.attention_config.is_sliding + self.sliding_window = self.attention_config.prefill_sliding_window + self.return_only_hidden_states = self.config.block_return_only_hidden_states + + @property + def device(self): + try: + return next(self.parameters()).device + except StopIteration: + return None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool | None = False, + output_router_logits: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + **kwargs, + ) -> ( + tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None] + | torch.FloatTensor + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + paramz = list(self.parameters()) + device = paramz[0].device if len(paramz) > 0 else None + if isinstance(hidden_states, tuple): + # could happen when sewing kit sends the output of the previous layer + # to this layer without going through the model forward unpacking code. 
+ # can be avoided by using config.block_return_only_hidden_states=True + hidden_states = hidden_states[0] + + hidden_states = hidden_states.to(device) + + if cache_position is not None: + cache_position = cache_position.to(device) + + if self.attention_config.llama4 is not None: + # chunk_size = self.attention_config.llama4.attention_chunk_size + # print(f"pre-llama4_update: {attention_mask=}") + # causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( + # attention_mask, hidden_states, cache_position, past_key_value, output_attentions, use_cache=use_cache, + # ) + # attention_mask = causal_mask if (chunk_size is None) else chunk_causal_mask + # if (past_key_value is not None) and isinstance(attention_mask, BlockMask): + # print(f"pre-adjust: {attention_mask.shape=}") + # print(f"pre-adjust: {hidden_states.shape=}") + # print(f"pre-adjust: {past_key_value.get_seq_length()=}") + # q_len = hidden_states.shape[1] + # kv_len = past_key_value.get_seq_length() + # if kv_len == 0: + # kv_len = q_len + # print(f"pre-adjust: {kv_len=} {q_len=}") + # print(f"post-adjust: {attention_mask.shape=}") + assert self.config.llama4_attn_implementation != "flex_attention", ( + "We have a mask issue with flex attention" + ) + + causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( + attention_mask, + hidden_states, + cache_position, + past_key_value, + output_attentions, + use_cache=use_cache, + ) + is_chunked = self.attention_config.llama4.attention_chunk_size is not None + attention_mask = ( + chunk_causal_mask if is_chunked and (chunk_causal_mask is not None) else causal_mask + ) + + else: + attention_mask = self._llama3_update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_value, output_attentions + ) + if self.attention_config.unshifted_sink and self.attention_config.is_sink: + attention_mask = self._unshifted_sink_mask( + attention_mask, + hidden_states, + self.attention_config.window_length, + self.attention_config.num_sink_tokens, + ) + else: + attention_mask = self._gemma2_window_mask( + attention_mask, hidden_states, past_key_value + ) + + self_attn_weights = None + present_key_value = past_key_value + router_logits = None + + if self.attention_config.no_op: + pass + elif self.attention_config.replace_with_linear or self.attention_config.is_mamba: + if self.attention_config.is_mamba: + assert past_key_value is None, "DeciLM does not support generation with Mamba yet" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + else: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_out = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states, self_attn_weights = attn_out[:2] + if len(attn_out) > 2: + present_key_value = attn_out[2] + + hidden_states = residual + hidden_states + + if not self.ffn_config.no_op: + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # Handle MoE layers differently as they return router logits + if self.ffn_config.is_moe: + hidden_states, router_logits = self.mlp(hidden_states) + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + if 
self.return_only_hidden_states: + return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits and router_logits is not None: + outputs += (router_logits,) + + return outputs + + def _gemma2_window_mask( + self, + attention_mask: torch.Tensor | None, + hidden_states: torch.Tensor, + past_key_value: VariableCache | None, + ) -> torch.Tensor | None: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # Flash-attn is a 2D tensor + if self.config._attn_implementation == "flash_attention_2": + if past_key_value is not None: # when decoding + attention_mask = attention_mask[:, -self.sliding_window :] + else: + min_dtype = torch.finfo(hidden_states.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + if attention_mask.shape[-1] <= 1: # when decoding + attention_mask = attention_mask[:, :, :, -self.sliding_window :] + return attention_mask + + def _unshifted_sink_mask( + self, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + window_length: int, + num_sink_tokens: int | None, + ) -> torch.Tensor: + assert self.config._attn_implementation == "eager", ( + "Unshifted sink is only supported in 'eager' mode." + ) + assert attention_mask is not None, "The attention mask seems to not be prepared" + + attention_mask = attention_mask.clone() + min_dtype = torch.finfo(hidden_states.dtype).min + + if window_length == 0: + attention_mask = torch.full_like(attention_mask, fill_value=min_dtype) + else: + query_length = attention_mask.shape[-2] + is_decode = query_length == 1 + if is_decode: + attention_mask[:, :, :, :-window_length] = min_dtype + else: + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + + attention_mask[:, :, :, :num_sink_tokens] = 0 + return attention_mask + + def _llama3_update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is + # 2D and of dynamic length even when the static KV cache is used. This is an issue for + # torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic + # shapes. (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. + # A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. + # See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
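+        # The 4D mask built below is additive: 0.0 where a query may attend and
+        # torch.finfo(dtype).min where it may not (the eager attention path later adds it to the
+        # attention scores). For example, a 3-token prefill with no padding yields rows
+        # [0, min, min], [0, 0, min], [0, 0, 0] over the first three key positions.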
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache" + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not using_static_cache + and not output_attentions + ): + if ( + AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ) + and not self.is_sliding + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @torch.compiler.disable(recursive=False) # the operations in this method are not compilable + def _llama4_update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache | None, + output_attentions: bool = False, + chunked_attention_mask=None, + use_cache=True, + ): + attn_implementation = self.config.llama4_attn_implementation + + if attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return ( + attention_mask, + attention_mask, + ) # flash does not support chunked attn TODO support flash + return None, None + + if attn_implementation not in ["sdpa", "flex_attention", "eager"]: + return None, None + + sequence_length = input_tensor.shape[1] + cache_position = cache_position.to(self.device) + attention_chunk_size = self.attention_config.llama4.attention_chunk_size + if attention_chunk_size is None: + # let the function build some chunked mask, we won't use it since it's not a chunked + # attention layer. 
We still need to know the chunk size for the
+            # `attn_implementation == "sdpa" and chunked_attention_mask is not None` branch further
+            # down; otherwise the mask dtype is wrong for sdpa.
+            attention_chunk_size = self.config.get_min_attention_chunk_size()
+            if attention_chunk_size is None:
+                logger.warning_once(
+                    "Could not infer attention_chunk_size since the model (or the model shard) "
+                    "has no chunked attention, using 8192 as default for mask construction"
+                )
+                attention_chunk_size = 8192
+
+        first_cache_position = cache_position[0]
+
+        if past_key_values is not None:
+            full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
+        else:
+            full_cache_length = (
+                attention_mask.shape[-1] if attention_mask is not None else sequence_length
+            )
+
+        cond1 = first_cache_position >= attention_chunk_size
+        cond2 = (first_cache_position < attention_chunk_size) & (
+            first_cache_position + sequence_length > attention_chunk_size
+        )
+        key_length = (
+            torch.where(
+                cond1,
+                attention_chunk_size + sequence_length - 1,
+                torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+            )
+            if use_cache
+            else full_cache_length
+        )
+
+        if attn_implementation == "flex_attention":
+            raise NotImplementedError("DeciLM Llama4 does not support flex attention")
+            # if isinstance(attention_mask, torch.Tensor):
+            #     offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
+            #     chunked_attention_mask = make_flex_block_causal_mask(
+            #         attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
+            #     )
+            #     attention_mask = make_flex_block_causal_mask(
+            #         attention_mask,
+            #         query_length=sequence_length,
+            #         key_length=full_cache_length,
+            #         offsets=(first_cache_position, 0),
+            #     )
+            #     return attention_mask, chunked_attention_mask
+            # if isinstance(attention_mask, BlockMask):
+            #     return attention_mask, chunked_attention_mask
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
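+        # Rough sketch of what follows: a dense 4D causal mask is always built for the full cache,
+        # and when full_cache_length exceeds attention_chunk_size an additional chunked mask is
+        # derived via create_chunked_attention_mask below, e.g. with attention_chunk_size=3 the
+        # token at position 3 starts a new chunk and cannot attend back to positions 0-2.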
+ dtype, device = input_tensor.dtype, input_tensor.device + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=max(full_cache_length, attention_chunk_size), + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + min_dtype=torch.finfo(dtype).min, + ) + if full_cache_length > attention_chunk_size: + start_idx = max(first_cache_position - attention_chunk_size + 1, 0) + end_idx = start_idx + key_length + chunked_attention_mask = self.create_chunked_attention_mask( + attention_chunk_size, + start=start_idx, # same offset as with flex + end=end_idx, + device=device, + ) + + ### Deci: we added this code to patch a bug in transformers + if attention_mask is None: + if past_key_values is not None: + raise NotImplementedError("We only support attention_mask=None is prefill") + attention_mask = torch.ones( + input_tensor.shape[0], input_tensor.shape[1], device=device, dtype=torch.long + ) + + local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well + # It may be smaller than attention_chunk_size -> pad it + requires_padding = local_attention_mask.shape[-1] < attention_chunk_size + if requires_padding: + local_attention_mask = nn.functional.pad( + local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) + ) + # Depending on the padding, take the query tokens from the end or the cache_position + if not requires_padding: + chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] + else: + chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] + + chunked_attention_mask = chunked_attention_mask.expand( + input_tensor.shape[0], -1, -1, -1 + ) + chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] + if attn_implementation == "eager": + min_dtype = torch.finfo(dtype).min + chunked_attention_mask = torch.where( + chunked_attention_mask == 0, min_dtype, 0.0 + ).to(dtype) + + # print(f"{output_attentions=}") + + if ( + attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and attention_mask.ndim == 4 + and not output_attentions # Only unmask for 4d masks + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if attn_implementation == "sdpa" and chunked_attention_mask is not None: + chunked_attention_mask = chunked_attention_mask.bool() + causal_mask = causal_mask.bool() + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=first_cache_position, + is_training=self.training, + ): + causal_mask = None + return causal_mask, chunked_attention_mask + + def create_chunked_attention_mask( + self, attention_chunk_size: int, start: int, end: int, device: torch.device + ) -> torch.Tensor: + """ + Generate the following: + + 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | + '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | + '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | + 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | + '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | + '?' 
: 5 ⬚ ⬚ ⬚ ■ ■ ■ | + + If the chunk size is 3. + This can just be appplied over the already created attention mask + """ + arange_vector = torch.arange(start, end, device=device) + block_pos = torch.abs( + arange_vector.unsqueeze(0) // attention_chunk_size + - arange_vector.unsqueeze(1) // attention_chunk_size + ) + token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) + mask = (block_pos == 0) & (token_pos <= 0) + return mask.to(device) + + +class DeciLMMultiDecoderLayer(nn.Module): + def __init__(self, config: DeciLMConfig, layer_idx: int): + super().__init__() + self.config = config + block_config = config.block_configs[layer_idx] + assert block_config.parallel_blocks is not None + num_parallel_blocks = len(block_config.parallel_blocks) + self.parallel_blocks = nn.ModuleList( + [ + DeciLMDecoderLayer(config, (layer_idx, internal_block_idx)) + for internal_block_idx in range(num_parallel_blocks) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + *args, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + block_outputs = [block(hidden_states, *args, **kwargs) for block in self.parallel_blocks] + output_hidden_states = [ + out[0].to(hidden_states.device) + if isinstance(out, tuple) + else out.to(hidden_states.device) + for out in block_outputs + ] + output_hidden_states = torch.stack(output_hidden_states, dim=0).sum(dim=0) + output_hidden_states = ( + output_hidden_states - (len(self.parallel_blocks) - 1) * hidden_states + ) + + if self.config.block_return_only_hidden_states: + return output_hidden_states + + other_outputs = block_outputs[0][1:] + outputs = (output_hidden_states, *other_outputs) + return outputs + + +DECILM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeciLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", + DECILM_START_DOCSTRING, +) +class DeciLMPreTrainedModel(PreTrainedModel): + config_class = DeciLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeciLMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True # all the _supports_... 
flags refer to the Llama3 layers + _supports_sdpa = False + _supports_flex_attn = False + _supports_cache_class = True + _supports_quantized_cache = False + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _prepare_generation_config( + self, + generation_config: GenerationConfig | None, + *args, + **kwargs, + ) -> tuple[GenerationConfig, dict]: + try: + from transformers import cache_utils + from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING + + need_setup_cache_classes_mapping = NEED_SETUP_CACHE_CLASSES_MAPPING + except Exception: + # older releases exposed it via generation.utils + need_setup_cache_classes_mapping = {} + + # DeciLM-specific code + generation_config, model_kwargs = super()._prepare_generation_config( + generation_config, *args, **kwargs + ) + # New transformers version, can reach only through cache_utils + if need_setup_cache_classes_mapping == {}: + cache_utils._CACHE_IMPLEMENTATION_MAPPING["variable"] = VariableCache + else: + need_setup_cache_classes_mapping["variable"] = VariableCache + + generation_config.cache_implementation = "variable" + return generation_config, model_kwargs + + +DECILM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`VariableCache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + If passed to the forward function, past_key_values must be a VariableCache object (see imports). + For generation purposes, this is already handled inside model.generate(). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare DeciLM Model outputting raw hidden-states without any specific head on top.", + DECILM_START_DOCSTRING, +) +class DeciLMModel(DeciLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeciLMDecoderLayer`] + + Args: + config: DeciLMConfig + """ + + def __init__(self, config: DeciLMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + ( + DeciLMDecoderLayer(config, layer_idx) + if (config.block_configs[layer_idx].parallel_blocks is None) + else DeciLMMultiDecoderLayer(config, layer_idx) + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: + self.rotary_emb = rope_type_to_class[self.config.position_embedding_type](config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_final_layer_norm(self): + return self.norm + + def set_final_layer_norm(self, value): + self.norm = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + is_legacy_cache_format = (past_key_values is not None) and type( + past_key_values + ).__name__ != "VariableCache" + # We use the __name__ instead of isinstance to support weird use cases + # (init cache from a checkpoint dir and use it with local code) + if is_legacy_cache_format: + raise NotImplementedError( + "DeciLMModel does not support legacy cache format, please use a newer " + "transformers version or use VariableCache explicitly (see import in this file)." 
+ ) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + # use default device + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = None + if hasattr(self, "rotary_emb"): + # rotary emb is created all devices, so we need to move position_ids to the correct device + some_param = next(self.parameters()) + position_ids = position_ids.to(some_param.device) + cache_position = cache_position.to(some_param.device) + faux_hidden_states = position_ids.to(some_param.dtype) + position_embeddings = self.rotary_emb(faux_hidden_states, position_ids) + # print(f'START {position_embeddings.device=}') # HF hook will change the device + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + if self.config.block_return_only_hidden_states: + hidden_states = layer_outputs + next_decoder_cache = past_key_values + + else: + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Extract router logits if they exist + if output_router_logits: + router_logits_index = -1 # Router logits are always the last element + if len(layer_outputs) > (2 if output_attentions else 1) + ( + 1 if use_cache else 0 + ): + all_router_logits += (layer_outputs[router_logits_index],) + + # Final layer norm + hidden_states = hidden_states.to(next(self.parameters()).device) + hidden_states = self.norm(hidden_states) + + # Add the last hidden state + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Set the next cache + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + outputs = (hidden_states, next_cache, all_hidden_states, all_self_attns) + if output_router_logits: + outputs += (all_router_logits,) + return outputs + + # Handle different return types based on whether router logits are requested + if output_router_logits and all_router_logits: + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + else: + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + 
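+# Illustrative usage sketch for the model classes defined in this file. The checkpoint path is a
+# placeholder; this assumes a checkpoint that ships this modeling code and is loaded with
+# trust_remote_code=True:
+#
+#     from transformers import AutoModelForCausalLM, AutoTokenizer
+#
+#     tokenizer = AutoTokenizer.from_pretrained("<path-to-decilm-checkpoint>")
+#     model = AutoModelForCausalLM.from_pretrained(
+#         "<path-to-decilm-checkpoint>", trust_remote_code=True
+#     )
+#     inputs = tokenizer("Hello, DeciLM!", return_tensors="pt")
+#     # generate() goes through DeciLMPreTrainedModel._prepare_generation_config, which registers
+#     # the "variable" cache implementation (VariableCache) before decoding.
+#     outputs = model.generate(**inputs, max_new_tokens=16)
+#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+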
+@add_start_docstrings( + """ + The DeciLM Model transformer with a sequence classification head on top (linear layer). + + [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DECILM_START_DOCSTRING, +) +class DeciLMForSequenceClassification(DeciLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeciLMModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + elif input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, *transformer_outputs[1:]) + return (loss, *output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + DECILM_START_DOCSTRING, +) +class DeciLMForQuestionAnswering(DeciLMPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM + def __init__(self, config): + super().__init__(config) + self.transformer = DeciLMModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.FloatTensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + start_positions: torch.LongTensor | None = None, + end_positions: torch.LongTensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits, *outputs[2:]) + return (total_loss, *output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + DECILM_START_DOCSTRING, +) +class DeciLMForTokenClassification(DeciLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeciLMModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, *outputs[2:]) + return (loss, *output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +######################################################################## +# DeciLM-specific code +######################################################################## + + +def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + +class DeciLMMoe(nn.Module): + """ + Implementation of Mixture of Experts module for DeciLM. + Equivalent to Llama4 MoE but implemented more frugally. + """ + + def __init__(self, config: DeciLMConfig, ffn_config: FFNConfig): + super().__init__() + self.config = config + self.ffn_config = ffn_config + + # MoE parameters + assert ffn_config.moe is not None, "MoE configuration must be provided to use DeciLMMoe" + self.moe_config: MoEConfig = ffn_config.moe + self.hidden_dim = config.hidden_size + self.num_experts_per_tok = self.moe_config.num_experts_per_tok + self.num_local_experts = self.moe_config.num_local_experts + self.expert_intermediate_dim = self.moe_config.expert_intermediate_dim + self.shared_expert_intermediate_dim = self.moe_config.shared_expert_intermediate_dim + + # Initialize experts and router + routed_expert_ffn_config = FFNConfig( + intermediate_size=self.expert_intermediate_dim, + ) + + self.experts = nn.ModuleList( + [ + DeciLMGatedMLP(config, routed_expert_ffn_config) + for _ in range(self.num_local_experts) + ] + ) + + self.router = nn.Linear(config.hidden_size, self.num_local_experts, bias=False) + + # Initialize shared expert as a standard MLP + shared_expert_ffn_config = FFNConfig( + intermediate_size=self.moe_config.shared_expert_intermediate_dim + ) + self.shared_expert = DeciLMGatedMLP(config, shared_expert_ffn_config) + + if ffn_config.sparsify is not None: + self.register_full_backward_hook(sparsity_backward_hook) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the MoE layer. 
+ + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch, seq_len, hidden_dim) + + Returns: + tuple: + - torch.Tensor: Output tensor of shape (batch, seq_len, hidden_dim) + - torch.Tensor: Router scores for loss computation + """ + router_logits = self.router(hidden_states) + + routed_out = self.forward_routed_experts(hidden_states, router_logits) + + shared_out = self.shared_expert(hidden_states) + + moe_out = routed_out + shared_out + + return moe_out, router_logits + + def forward_routed_experts( + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ) -> torch.Tensor: + """ + For each expert: + 1. Build the input to the expert based on the router mask + 2. Run the expert + 3. Add the result of the expert into the total MoE result using += + """ + router_top_values, router_indices = torch.topk( + router_logits, self.num_experts_per_tok, dim=-1 + ) + router_scores = torch.sigmoid(router_top_values.float()).to(hidden_states.dtype) + + routed_out = torch.zeros_like(hidden_states) + for i_expert in range(self.num_local_experts): + expert_mask = router_indices == i_expert + if expert_mask.any(): + is_token_routed_to_this_expert = expert_mask.any(dim=-1) + relevant_hidden_states = hidden_states[is_token_routed_to_this_expert, :] + relevant_scores = router_scores[expert_mask] + expert_in = relevant_hidden_states * relevant_scores.unsqueeze(-1) + + expert_out = self.experts[i_expert](expert_in).to(hidden_states.device) + + routed_out[is_token_routed_to_this_expert, :] += expert_out + + return routed_out + + def extra_repr(self) -> str: + return ( + f"(MoE): num_local_experts={self.num_local_experts}, " + f"expert_intermediate_dim={self.expert_intermediate_dim}," + ) + + +class DeciLMLinearMLP(nn.Module): + # DeciLM-specific code + def __init__( + self, + config: DeciLMConfig, + ): + super().__init__() + self.linear_mlp = nn.Linear( + in_features=config.hidden_size, out_features=config.hidden_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_mlp.forward(x) + + +class DeciLMLinearAttention(nn.Module): + # DeciLM-specific code + def __init__( + self, + config: DeciLMConfig, + ): + super().__init__() + self.linear_attn = nn.Linear( + in_features=config.hidden_size, out_features=config.hidden_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_attn.forward(x) + + +def sparsity_backward_hook(*args, **kwargs): + raise NotImplementedError( + "No support for sparsity when training HF DeciLM (inference is ok though)" + ) + + +class DeciLMMambaMixer(nn.Module): + def __init__( + self, + config: DeciLMConfig, + mamba_config: MambaConfig, + ): + super().__init__() + self.mamba_mixer = MambaMixerMegatron( + d_model=config.hidden_size, + d_state=mamba_config.state_dim, + nheads=mamba_config.num_heads, + headdim=mamba_config.head_dim, + ngroups=mamba_config.num_groups, + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + x = x.permute([1, 0, 2]) # MambaMixerMegatron expects [Sequence, Batch, Embedding] + out = self.mamba_mixer(x) + out = out.permute([1, 0, 2]) # go back to [Batch, Sequence, Embedding] + return out + + +class LMHead(nn.Linear): + """ + Special class to allow FSDP wrapping without affecting other Linear layers in the model. 
+ """ + + +class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: DeciLMConfig): + super().__init__(config) + self.model = DeciLMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_router_aux_loss(self, router_logits): + """ + Computes the auxiliary loss for router logits. + This encourages load balancing across experts. + + Args: + router_logits: List of router logits tensors from each MoE layer + Each tensor has shape [batch_size, sequence_length, num_experts] + + Returns: + Auxiliary loss tensor + """ + aux_loss = torch.tensor(0.0, device=router_logits[0].device) + + for layer_idx, layer_router_logits in enumerate(router_logits): + router_probs = torch.softmax(layer_router_logits, dim=-1) + + # Mean routing probability across batch and sequence dimensions + mean_prob = router_probs.mean(dim=[0, 1]) + + # Compute auxiliary loss: combination of load balancing and importance loss + # Load balancing loss: variance of expert usage probabilities (should be uniform) + num_experts = mean_prob.size(0) + ideal_prob = 1.0 / num_experts + balance_loss = torch.sum((mean_prob - ideal_prob) ** 2) + + # Add this layer's auxiliary loss to the total + aux_loss = aux_loss + balance_loss + + # Average over all layers + if len(router_logits) > 0: + aux_loss = aux_loss / len(router_logits) + + return aux_loss + + @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + output_router_logits: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | CausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Return: + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Extract model outputs based on return type + if isinstance(outputs, MoeModelOutputWithPast): + hidden_states = outputs.last_hidden_state + router_logits = outputs.router_logits + elif return_dict: + hidden_states = outputs.last_hidden_state + router_logits = None # No router logits in this case + else: + hidden_states = outputs[0] + router_logits = outputs[4] if output_router_logits and len(outputs) > 4 else None + + # Generate logits + logits = self.lm_head(hidden_states) + logits = logits.float() + + # Calculate loss if labels are provided + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # Calculate router aux loss if router logits are present + if router_logits is not None and self.config.router_aux_loss_coef > 0: + aux_loss = self.compute_router_aux_loss(router_logits) + loss = loss + aux_loss * self.config.router_aux_loss_coef + + # Handle non-dict return + if not return_dict: + output = (logits,) + if isinstance(outputs, tuple): + output += outputs[1:] # Add all other outputs + return (loss, *output) if loss is not None else output + + # Different output types for MoE vs regular model + if router_logits is not None: + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else outputs[1], + hidden_states=outputs.hidden_states + if return_dict + else outputs[2] + if output_hidden_states + else None, + attentions=outputs.attentions + if return_dict + else outputs[3] + if output_attentions + else None, + router_logits=router_logits, + ) + else: + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else outputs[1], + hidden_states=outputs.hidden_states + if return_dict + else outputs[2] + if output_hidden_states + else None, + attentions=outputs.attentions + if return_dict + else outputs[3] + if output_attentions + else None, + ) From 01f4fc15d50a88af30365848325204e546578eb8 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 20:45:01 +0100 Subject: [PATCH 14/49] Make llama3 converter self-contained (no deps on internal Nvidia code) Signed-off-by: Daniel Korzekwa --- 
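
For reference, the router auxiliary loss wired into DeciLMForCausalLM.forward above boils down to the short standalone sketch below; the batch shape, expert count, and the three-layer list are illustrative assumptions, not values taken from this patch series.

# Minimal sketch of the load-balancing term from compute_router_aux_loss (illustrative shapes).
import torch

batch, seq_len, num_experts = 2, 4, 8
router_logits = [torch.randn(batch, seq_len, num_experts) for _ in range(3)]  # one tensor per MoE layer

aux_loss = torch.tensor(0.0)
for layer_router_logits in router_logits:
    router_probs = torch.softmax(layer_router_logits, dim=-1)
    mean_prob = router_probs.mean(dim=[0, 1])   # mean routing probability per expert
    ideal_prob = 1.0 / num_experts              # uniform target distribution
    aux_loss = aux_loss + torch.sum((mean_prob - ideal_prob) ** 2)
aux_loss = aux_loss / len(router_logits)        # averaged over the MoE layers
# The model then adds aux_loss * config.router_aux_loss_coef to the cross-entropy loss.
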
.../_compress/decilm/checkpoint_utils.py | 191 ++++++++ .../_compress/decilm/checkpoint_utils_hf.py | 444 ++++++++++++++++++ .../_compress/decilm/conversion_utils.py | 157 +++++++ .../converters/convert_llama3_to_decilm.py | 6 +- 4 files changed, 795 insertions(+), 3 deletions(-) create mode 100644 modelopt/torch/_compress/decilm/checkpoint_utils.py create mode 100644 modelopt/torch/_compress/decilm/checkpoint_utils_hf.py create mode 100644 modelopt/torch/_compress/decilm/conversion_utils.py diff --git a/modelopt/torch/_compress/decilm/checkpoint_utils.py b/modelopt/torch/_compress/decilm/checkpoint_utils.py new file mode 100644 index 000000000..c2a8c08d5 --- /dev/null +++ b/modelopt/torch/_compress/decilm/checkpoint_utils.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import concurrent.futures +import warnings +from functools import partial +from pathlib import Path +from typing import Literal, TypeVar + +import torch +from puzzle_tools.common import infer_weights_dtype +from safetensors.torch import load_file as safe_load_file +from torch import nn +from transformers import AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME + +from modelopt.torch._compress.decilm.checkpoint_utils_hf import load_model_config + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +STATE_DICT_FILE_NAME = "model.pth" + +warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") + + +def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: + checkpoint_dir = _normalize_checkpoint_dir(checkpoint_dir) + + if (state_dict_path := checkpoint_dir / STATE_DICT_FILE_NAME).exists(): + return torch.load(state_dict_path, map_location="cpu", weights_only=False) + + if (safetensors_subblocks_dir := checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(safetensors_subblocks_dir) + + if (pth_subblocks_dir := checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME).exists(): + return _load_state_dict_from_subblocks(pth_subblocks_dir) + + if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( + checkpoint_dir / SAFE_WEIGHTS_NAME + ).exists(): + from utils.sharded_checkpoint_utils import ( + load_sharded_state_dict, # local import to avoid circular import + ) + + return load_sharded_state_dict(checkpoint_dir) + + raise FileNotFoundError( + f"Couldn't find state dict path or subblocks dir inside {checkpoint_dir}" + ) + + +def _normalize_checkpoint_dir(checkpoint_dir: Path | str) -> Path: + checkpoint_dir = Path(checkpoint_dir) + if checkpoint_dir.is_file(): + checkpoint_dir = checkpoint_dir.parent + return checkpoint_dir + + +def _load_state_dict_from_subblocks(subblocks_dir: Path) -> dict[str, torch.Tensor]: + torch_paths = list(subblocks_dir.glob("*.pth")) + safetensors_paths = 
list(subblocks_dir.glob("*.safetensors")) + + if len(torch_paths) != 0: + load_fn = partial(torch.load, map_location="cpu", weights_only=False) + file_paths = torch_paths + elif len(safetensors_paths) != 0: + load_fn = safe_load_file + file_paths = safetensors_paths + else: + raise ValueError(f"No tensor files found in {subblocks_dir=}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + state_dict_shards = list(executor.map(load_fn, file_paths)) + + state_dict = {k: v for shard in state_dict_shards for k, v in shard.items()} + return state_dict + + +NNModule = TypeVar("NNModule", bound=nn.Module) + + +def init_module_with_state_dict( + state_dict: dict[str, torch.Tensor], + module_cls: type[NNModule], + *init_args, + **init_kwargs, +) -> NNModule: + weights_dtype = infer_weights_dtype(state_dict) + module = init_empty_module(module_cls, weights_dtype, *init_args, **init_kwargs) + module.load_state_dict(state_dict) + return module + + +def init_empty_module( + module_cls: type[NNModule], + dtype: torch.dtype, + *init_args, + **init_kwargs, +) -> NNModule: + default_dtype = torch.get_default_dtype() + current_device = torch.ones(1).device + torch.set_default_dtype(dtype) + module = skip_init(module_cls, *init_args, device=current_device, **init_kwargs) + torch.set_default_dtype(default_dtype) + return module + + +def skip_init(module_cls, *args, **kwargs) -> nn.Module: + """ + Heavily inspired by torch.nn.utils.skip_init but does not require the module to accept a "device" kwarg. + """ + if not issubclass(module_cls, torch.nn.Module): + raise RuntimeError(f"Expected a Module; got {module_cls}") + + final_device = kwargs.pop("device", "cpu") + with torch.device("meta"): + module = module_cls(*args, **kwargs) + + module = module.to_empty(device=final_device) + return module + + +def is_valid_decilm_checkpoint(checkpoint_dir: Path | str) -> bool: + """Validate that a checkpoint is in DeciLM format (has block_configs). + + Args: + checkpoint_dir: Path to checkpoint directory + + Returns: + True if checkpoint is valid DeciLM format, False otherwise + """ + try: + model_config = load_model_config(checkpoint_dir) + if model_config.block_configs is None: + warnings.warn( + f"Skipping checkpoint '{checkpoint_dir}' - not in DeciLM format (missing block_configs)" + ) + return False + return True + except Exception as e: + warnings.warn(f"Skipping checkpoint '{checkpoint_dir}' - failed to load config: {e}") + return False + + +def copy_tokenizer( + source_dir_or_tokenizer_name: Path | str, + target_dir: Path | str, + on_failure: Literal["raise", "warn"] = "raise", +) -> None: + """ + Prefer loading the tokenizer from huggingface hub (when tokenizer_name.txt file is available) + to avoid collision between transformers versions. 
+ """ + source_tokenizer_name_path = Path(source_dir_or_tokenizer_name) / "tokenizer_name.txt" + if source_tokenizer_name_path.exists(): + source_dir_or_tokenizer_name = source_tokenizer_name_path.read_text().strip() + + tokenizer = None + try: + tokenizer = AutoTokenizer.from_pretrained( + source_dir_or_tokenizer_name, trust_remote_code=True + ) + except Exception: + message = f"Couldn't load tokenizer from '{source_dir_or_tokenizer_name}'" + if on_failure == "raise": + raise FileNotFoundError(message) + else: + warnings.warn(message) + + if tokenizer is not None: + target_dir = Path(target_dir) + target_dir.mkdir(exist_ok=True, parents=True) + tokenizer.save_pretrained(target_dir) + + target_tokenizer_name_path = target_dir / "tokenizer_name.txt" + is_given_tokenizer_name_as_argument = not Path(source_dir_or_tokenizer_name).exists() + if is_given_tokenizer_name_as_argument: + target_tokenizer_name_path.write_text(source_dir_or_tokenizer_name) diff --git a/modelopt/torch/_compress/decilm/checkpoint_utils_hf.py b/modelopt/torch/_compress/decilm/checkpoint_utils_hf.py new file mode 100644 index 000000000..3cfbf93c5 --- /dev/null +++ b/modelopt/torch/_compress/decilm/checkpoint_utils_hf.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# mypy: ignore-errors + +import concurrent.futures +import fcntl +import os +import shutil +import time +import warnings +from collections import defaultdict +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import Any, BinaryIO + +import torch +from logger import mprint +from puzzle_tools import deci_lm_hf_code +from puzzle_tools.common import infer_weights_dtype +from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from puzzle_tools.robust_json import json_dumps +from safetensors.torch import save_file as safe_save_file +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from utils.post_init_sparse import SparsityMethod + +SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" +PTH_SUBBLOCKS_DIR_NAME = "subblocks" +RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME) + + +# TODO: (esegal) Should ask the model for something like this +NON_LAYER_MODULE_TO_FILE_TYPE = { + "model.embed_tokens": "embeddings", + "model.norm": "lm_head", + "lm_head": "lm_head", +} +MODULE_WITHIN_LAYER_TO_FILE_TYPE = { + "input_layernorm": "attention", + "self_attn": "attention", + "post_attention_layernorm": "ffn", + "mlp": "ffn", + "parallel_blocks": "multi_block", +} +LAYERS_MODULE_NAME = "model.layers" + +warnings.filterwarnings("ignore", "You are using `torch.load` with `weights_only=False`*.") + + +def load_checkpoint( + checkpoint_dir: Path | str, + model_config_overrides: dict | None = None, + ignore_unexpected_config_keys: bool = False, +) -> DeciLMForCausalLM: + """ + Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your + local repo code, not the code inside the checkpoint. + """ + from puzzle_tools.checkpoint_utils import load_state_dict # prevent circular import + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + model_config = load_model_config( + checkpoint_dir, model_config_overrides, ignore_unexpected_config_keys + ) + + # Without sparsity we could have done: + # model = DeciLMForCausalLM.from_pretrained(pretrained_model_name_or_path=checkpoint_dir, config=model_config) + state_dict = load_state_dict(checkpoint_dir) + state_dict, sparsity_masks = SparsityMethod.fix_state_dict_inplace(state_dict, verbose=True) + dtype = infer_weights_dtype(state_dict) + model = DeciLMForCausalLM.from_pretrained( + pretrained_model_name_or_path=None, + config=model_config, + state_dict=state_dict, + torch_dtype=dtype, + ) + SparsityMethod().apply_masks(model, sparsity_masks) + + return model + + +def load_model_config( + checkpoint_dir: Path | str, + model_config_overrides: Mapping | None = None, + ignore_unexpected_config_keys: bool = False, +) -> DeciLMConfig: + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + if model_config_overrides is None: + model_config_overrides = {} + + config, unused_kwargs = DeciLMConfig.from_pretrained( + checkpoint_dir, return_unused_kwargs=True, **model_config_overrides + ) + + if not ignore_unexpected_config_keys: + if unused_kwargs: + raise ValueError(f"Unexpected config keys: {unused_kwargs.keys()}") + + return config + + +def save_checkpoint(model: DeciLMForCausalLM, checkpoint_dir: Path | str) -> None: + _save_checkpoint(model.config, model.state_dict(), checkpoint_dir) + + +def _save_checkpoint( + model_config: DeciLMConfig, + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + max_workers: int | None = 
None, # Now optional - will auto-calculate if None +) -> None: + mprint("=== Starting _save_checkpoint detailed profiling ===") + total_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Phase 1: Create directory and save config + phase1_start_time = time.time() + checkpoint_dir.mkdir(parents=True, exist_ok=True) + model_config.save_pretrained(checkpoint_dir) + phase1_time = time.time() - phase1_start_time + mprint(f"Phase 1 - Directory creation and config save: {phase1_time:.2f}s") + + # Phase 2: Save subblocks (main model weights) with auto-calculated worker count + phase2_start_time = time.time() + save_subblocks( + state_dict, + checkpoint_dir, + multi_threaded=True, + max_workers=max_workers, # Will auto-calculate if None + ) + phase2_time = time.time() - phase2_start_time + mprint(f"Phase 2 - Save subblocks (model weights): {phase2_time:.2f}s") + + # Phase 3: Save safetensors index + phase3_start_time = time.time() + save_safetensors_index(model_config, checkpoint_dir) + phase3_time = time.time() - phase3_start_time + mprint(f"Phase 3 - Save safetensors index: {phase3_time:.2f}s") + + # Phase 4: Copy HF code + phase4_start_time = time.time() + copy_deci_lm_hf_code(checkpoint_dir) + phase4_time = time.time() - phase4_start_time + mprint(f"Phase 4 - Copy HF code: {phase4_time:.2f}s") + + total_time = time.time() - total_start_time + mprint(f"=== _save_checkpoint completed in {total_time:.2f}s ===") + mprint( + f"Breakdown: Config {phase1_time:.1f}s + Subblocks {phase2_time:.1f}s + " + f"Index {phase3_time:.1f}s + HF code {phase4_time:.1f}s" + ) + mprint( + f"Save percentage breakdown: Config {phase1_time / total_time * 100:.1f}% + " + f"Subblocks {phase2_time / total_time * 100:.1f}% + " + f"Index {phase3_time / total_time * 100:.1f}% + " + f"HF code {phase4_time / total_time * 100:.1f}%" + ) + + # Performance metrics + if phase2_time > 0: + subblocks_percentage = phase2_time / total_time * 100 + actual_workers = max_workers if max_workers else "auto" + mprint( + f"I/O optimization: Subblocks were {subblocks_percentage:.1f}% of total save time " + f"(max_workers={actual_workers})" + ) + + +def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: + from puzzle_tools.checkpoint_utils import load_state_dict # prevent circular import + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + model_config = load_model_config(checkpoint_dir) + state_dict = load_state_dict(checkpoint_dir) + save_subblocks(state_dict, checkpoint_dir) + + if (index_path := checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists(): + index_path.rename(checkpoint_dir / f"before_splitting.{SAFE_WEIGHTS_INDEX_NAME}") + save_safetensors_index(model_config, checkpoint_dir) + + +def save_subblocks( + state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + multi_threaded: bool = True, + max_workers: int | None = None, # Now optional - will auto-calculate if None +) -> None: + mprint("=== Starting save_subblocks detailed profiling ===") + subblocks_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Step 1: Build weight map + weight_map_start_time = time.time() + weight_map = _build_safetensors_weight_map( + state_dict=state_dict, + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) + weight_name_to_filename = {k: 
checkpoint_dir / v for k, v in weight_map.items()} + weight_map_time = time.time() - weight_map_start_time + mprint(f" Step 1 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 2: Create subblocks directory + dir_create_start_time = time.time() + subblocks_path = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME + subblocks_path.mkdir(parents=True, exist_ok=True) + dir_create_time = time.time() - dir_create_start_time + mprint(f" Step 2 - Create directory: {dir_create_time:.2f}s") + + # Step 3: Organize tensors by file + organize_start_time = time.time() + filename_to_partial_state_dict = defaultdict(dict) + total_tensor_size = 0 + for weight_name, weight in state_dict.items(): + if weight_name in weight_map: + # Ensure tensor is contiguous and on CPU for faster I/O + tensor = ( + weight.contiguous().cpu() if weight.device.type != "cpu" else weight.contiguous() + ) + filename_to_partial_state_dict[weight_name_to_filename[weight_name]][weight_name] = ( + tensor + ) + total_tensor_size += weight.numel() * weight.element_size() + organize_time = time.time() - organize_start_time + mprint( + f" Step 3 - Organize tensors: {organize_time:.2f}s ({total_tensor_size / (1024**3):.2f}GB total)" + ) + + # Step 4: Prepare save arguments and auto-calculate optimal I/O workers + prepare_start_time = time.time() + safe_save_kwargs = [ + {"tensors": partial_state_dict, "filename": filename, "metadata": {"format": "pt"}} + for filename, partial_state_dict in filename_to_partial_state_dict.items() + ] + + # Auto-calculate optimal I/O workers: min(cpu_count, num_files) + if max_workers is None: + cpu_count = os.cpu_count() or 1 + num_files = len(safe_save_kwargs) + max_workers = min(cpu_count, num_files) + mprint( + f" Auto-calculated I/O workers: min({cpu_count} CPUs, {num_files} files) = {max_workers}" + ) + else: + mprint(f" Using specified I/O workers: {max_workers}") + + prepare_time = time.time() - prepare_start_time + mprint(f" Step 4 - Prepare save args: {prepare_time:.2f}s ({len(safe_save_kwargs)} files)") + + # Step 5: Save files with optimal worker count + save_start_time = time.time() + if multi_threaded: + mprint(f" Using multi-threaded saving with {max_workers} workers...") + + def optimized_safe_save(kwargs): + try: + safe_save_file(**kwargs) + return True + except Exception as e: + mprint(f" Error saving {kwargs['filename']}: {e}") + return False + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(optimized_safe_save, safe_save_kwargs)) + + # Check for any failures + failed_saves = sum(1 for r in results if not r) + if failed_saves > 0: + mprint(f" Warning: {failed_saves} files failed to save") + else: + mprint(" Using single-threaded saving...") + for kwargs in safe_save_kwargs: + safe_save_file(**kwargs) + + save_time = time.time() - save_start_time + mprint(f" Step 5 - Save files: {save_time:.2f}s ({max_workers} workers)") + + subblocks_total_time = time.time() - subblocks_start_time + mprint(f"=== save_subblocks completed in {subblocks_total_time:.2f}s ===") + mprint( + f" Breakdown: WeightMap {weight_map_time:.1f}s + DirCreate {dir_create_time:.1f}s + " + f"Organize {organize_time:.1f}s + Prepare {prepare_time:.1f}s + Save {save_time:.1f}s" + ) + + # Calculate effective I/O speed + io_speed_gbps = (total_tensor_size / (1024**3)) / save_time if save_time > 0 else 0 + mprint(f" Effective I/O speed: {io_speed_gbps:.2f} GB/s ({max_workers} workers)") + mprint(f" Save operation was {save_time / 
subblocks_total_time * 100:.1f}% of total time") + + +def save_safetensors_index( + model_config: DeciLMConfig, + checkpoint_dir: Path | str, +) -> None: + mprint("=== Starting save_safetensors_index profiling ===") + index_start_time = time.time() + + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + # Step 1: Create fake model on meta device + fake_model_start_time = time.time() + with torch.device("meta"): + fake_model = DeciLMForCausalLM(model_config) + fake_model_time = time.time() - fake_model_start_time + mprint(f" Step 1 - Create fake model: {fake_model_time:.2f}s") + + # Step 2: Build weight map + weight_map_start_time = time.time() + weight_map = _build_safetensors_weight_map( + state_dict=fake_model.state_dict(), + non_layer_module_to_file_type=NON_LAYER_MODULE_TO_FILE_TYPE, + module_within_layer_to_file_type=MODULE_WITHIN_LAYER_TO_FILE_TYPE, + layers_module_name=LAYERS_MODULE_NAME, + ) + weight_map_time = time.time() - weight_map_start_time + mprint(f" Step 2 - Build weight map: {weight_map_time:.2f}s ({len(weight_map)} mappings)") + + # Step 3: Create and write index + write_start_time = time.time() + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + index_path = checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME + index_json = json_dumps(index) + _write_file_process_safe(index_json, index_path) + write_time = time.time() - write_start_time + mprint(f" Step 3 - Write index file: {write_time:.2f}s ({len(index_json)} chars)") + + index_total_time = time.time() - index_start_time + mprint(f"=== save_safetensors_index completed in {index_total_time:.2f}s ===") + mprint( + f" Breakdown: FakeModel {fake_model_time:.1f}s + WeightMap {weight_map_time:.1f}s + Write {write_time:.1f}s" + ) + + +def _write_text(content: str, f: BinaryIO) -> None: + f.write(content.encode("utf-8")) + + +def _write_file_process_safe( + content: Any, + path: Path | str, + write_fn: Callable[[Any, BinaryIO], None] = _write_text, +) -> None: + """ + Write a file in a multi-process safe way. + If another process tries to write the same file using this method, the current process + "gives up" and assumes that the matter is being taken care of by another process. + + write_fn is a function that receives file contents and a binary file object, + and writes the content to the file. It can be _write_text (defined above), or torch.save, + or a similar function (not safetensors.torch.save_file since it expects a path). 
+ """ + with open(path, "wb") as f: + # Try to acquire an exclusive, non-blocking lock + try: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + return # Exit immediately if the lock is not acquired + + write_fn(content, f) # Write the content if lock is acquired + f.flush() # Ensure data is written to disk + + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + + +def _build_safetensors_weight_map( + *, + state_dict: dict[str, torch.Tensor], + non_layer_module_to_file_type: dict[str, str], + module_within_layer_to_file_type: dict[str, str], + layers_module_name: str, +) -> dict[str, Path]: + weight_map = {} + unmapped_weight_names = [] + for weight_name in state_dict: + found_match = False + for module_name, file_type in non_layer_module_to_file_type.items(): + if weight_name.startswith(f"{module_name}."): + weight_map[weight_name] = str(RELATIVE_SUBBLOCKS_DIR / f"{file_type}.safetensors") + found_match = True + if not found_match: + if weight_name.startswith(f"{layers_module_name}."): + name_parts = weight_name[len(layers_module_name) + 1 :].split(".") + layer_index = name_parts[0] + name_within_layer = ".".join(name_parts[1:]) + + for module_name, file_type in module_within_layer_to_file_type.items(): + if name_within_layer.startswith(f"{module_name}."): + weight_map[weight_name] = str( + RELATIVE_SUBBLOCKS_DIR / f"block_{layer_index}_{file_type}.safetensors" + ) + found_match = True + + if not found_match: + unmapped_weight_names.append(weight_name) + + if len(unmapped_weight_names) > 0: + raise ValueError( + f"Unmapped weight names: {unmapped_weight_names}\n" + f"Add them to the `non_layer_module_to_file_type` or " + f"`module_within_layer_to_file_type` dictionaries." + ) + + return weight_map + + +# Not really needed +def save_model_config(model_config: DeciLMConfig, checkpoint_dir: Path | str) -> None: + model_config.save_pretrained(checkpoint_dir) + + +def copy_deci_lm_hf_code(output_dir: Path | str) -> None: + """ + Copy the deci_lm_hf_code directory to the output directory. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + code_dir = Path(deci_lm_hf_code.__file__).parent + for path in code_dir.glob("*.py"): + shutil.copy(path, output_dir / path.name) diff --git a/modelopt/torch/_compress/decilm/conversion_utils.py b/modelopt/torch/_compress/decilm/conversion_utils.py new file mode 100644 index 000000000..deb080ea2 --- /dev/null +++ b/modelopt/torch/_compress/decilm/conversion_utils.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from collections import defaultdict + +from safetensors.torch import load_file, save_file +from tqdm import tqdm + + +def convert_name(name): + return name.replace("feed_forward", "mlp").replace("language_model.", "") + + +def convert_routed_experts_weight(llama_name, weight): + assert ".experts." 
in llama_name, "Only use this func to convert weights of routed experts" + llama_name_prefix = llama_name.split(".experts.")[0] + deci_name_prefix = convert_name(llama_name_prefix) + + experts_state_dict = {} + for i_expert, expert_weight in enumerate(weight.unbind(dim=0)): + expert_prefix = f"{deci_name_prefix}.experts.{i_expert}" + if "gate_up_proj" in llama_name: + gate_weight, up_weight = expert_weight.transpose(0, 1).chunk(2, dim=0) + experts_state_dict[f"{expert_prefix}.gate_proj.weight"] = gate_weight.contiguous() + experts_state_dict[f"{expert_prefix}.up_proj.weight"] = up_weight.contiguous() + elif "down_proj" in llama_name: + down_weight = expert_weight.transpose(0, 1) + experts_state_dict[f"{expert_prefix}.down_proj.weight"] = down_weight.contiguous() + else: + raise ValueError(f"Unknown expert weight: {llama_name}") + + return experts_state_dict + + +def get_layer_subblock(param): + if param.startswith("model.embed_tokens."): + return "embeddings" + if param.startswith("lm_head.") or param == "model.norm.weight": + return "lm_head" + m = re.match(r"model\.layers\.(\d+)\.(.+)", param) + if m: + layer, suffix = m.groups() + if suffix.startswith(("self_attn.", "input_layernorm.weight")): + return f"block_{layer}_attention" + elif suffix.startswith(("mlp.", "post_attention_layernorm.weight")): + return f"block_{layer}_ffn" + return None + + +def convert_model_weights_to_decilm(llama_hf_dir, output_dir, is_llama4=False): + index_path = os.path.join(llama_hf_dir, "model.safetensors.index.json") + single_file_path = os.path.join(llama_hf_dir, "model.safetensors") + + # Check if we have a sharded model (with index) or single file model + if os.path.exists(index_path): + # Sharded model - use existing logic + with open(index_path) as f: + index = json.load(f) + param_to_file = index["weight_map"] + all_param_names = list(param_to_file.keys()) + elif os.path.exists(single_file_path): + # Single file model - create a synthetic index + data = load_file(single_file_path) + all_param_names = list(data.keys()) + param_to_file = dict.fromkeys(all_param_names, "model.safetensors") + else: + raise FileNotFoundError( + f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." + ) + name_map = { + name: convert_name(name) + for name in all_param_names + if name.startswith("language_model.") or not is_llama4 + } + + # Reverse map: file -> set of params + file_to_params = defaultdict(set) + for name, file in param_to_file.items(): + file_to_params[file].add(name) + + # Determine subblocks needed + subblocks = defaultdict(list) + for old_name, new_name in name_map.items(): + subblock = get_layer_subblock(new_name) + if subblock: + subblocks[subblock].append((old_name, new_name)) + + # Output directory + out_dir = os.path.join(output_dir, "subblocks_safetensors") + os.makedirs(out_dir, exist_ok=True) + + # New weight index + new_index = {"metadata": {"format": "pt"}, "weight_map": {}} + + # For single file models, load all data once + if os.path.exists(single_file_path) and not os.path.exists(index_path): + all_data = load_file(single_file_path) + else: + all_data = None + + for subblock, param_pairs in tqdm(subblocks.items(), desc="Processing subblocks"): + tensors = {} + + if all_data is not None: + # Single file model - get tensors from pre-loaded data + for old_name, new_name in param_pairs: + if old_name in all_data: + if ".experts." 
not in old_name: + tensors[new_name] = all_data[old_name] + else: + experts_state_dict = convert_routed_experts_weight( + old_name, all_data[old_name] + ) + tensors.update(experts_state_dict) + else: + # Sharded model - load only needed files for this subblock + param_files = {param_to_file[old] for old, _ in param_pairs} + for file in param_files: + data = load_file(os.path.join(llama_hf_dir, file)) + for old_name, new_name in param_pairs: + if param_to_file[old_name] == file and old_name in data: + if ".experts." not in old_name: + tensors[new_name] = data[old_name] + else: + experts_state_dict = convert_routed_experts_weight( + old_name, data[old_name] + ) + tensors.update(experts_state_dict) + + # Save this subblock + subblock_file = f"{subblock}.safetensors" + save_file(tensors, os.path.join(out_dir, subblock_file)) + + # Update index + for new_name in tensors: + new_index["weight_map"][new_name] = f"subblocks_safetensors/{subblock_file}" + + # Save new index file + with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f: + json.dump(new_index, f, indent=2) + + print(f"✅ Finished saving subblocks and index to {output_dir}") diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index 96b96f351..1b59ea800 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -21,11 +21,11 @@ import torch from fire import Fire -from puzzle_tools.checkpoint_utils import copy_tokenizer -from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code -from puzzle_tools.conversion_utils import convert_model_weights_to_decilm from transformers import LlamaConfig +from modelopt.torch._compress.decilm.checkpoint_utils import copy_tokenizer +from modelopt.torch._compress.decilm.checkpoint_utils_hf import copy_deci_lm_hf_code +from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig """ From c57eed4ad4144d022d5ea81f7b38fe04d24c2d81 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 20:49:58 +0100 Subject: [PATCH 15/49] Add common module Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/common.py | 22 +++++++++++++++++++ .../_compress/decilm/checkpoint_utils.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/common.py diff --git a/modelopt/torch/_compress/common.py b/modelopt/torch/_compress/common.py new file mode 100644 index 000000000..96db57280 --- /dev/null +++ b/modelopt/torch/_compress/common.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + + +def infer_weights_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype: + weights_dtype = [p.dtype for p in state_dict.values() if torch.is_floating_point(p)] + weights_dtype = weights_dtype[0] if len(weights_dtype) > 0 else torch.get_default_dtype() + return weights_dtype diff --git a/modelopt/torch/_compress/decilm/checkpoint_utils.py b/modelopt/torch/_compress/decilm/checkpoint_utils.py index c2a8c08d5..d35ec7975 100644 --- a/modelopt/torch/_compress/decilm/checkpoint_utils.py +++ b/modelopt/torch/_compress/decilm/checkpoint_utils.py @@ -21,12 +21,12 @@ from typing import Literal, TypeVar import torch -from puzzle_tools.common import infer_weights_dtype from safetensors.torch import load_file as safe_load_file from torch import nn from transformers import AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from modelopt.torch._compress.common import infer_weights_dtype from modelopt.torch._compress.decilm.checkpoint_utils_hf import load_model_config SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" From 3dc37b31141afefbdbe6c3b825ea4a8895b278fc Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Tue, 4 Nov 2025 21:40:24 +0100 Subject: [PATCH 16/49] module refactoring Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- .../decilm/converters/convert_llama3_to_decilm.py | 4 ++-- .../torch/_compress/nas/plugins/compress_nas_plugin.py | 4 ++-- .../torch/_compress/{decilm => tools}/checkpoint_utils.py | 4 ++-- .../_compress/{decilm => tools}/checkpoint_utils_hf.py | 8 ++++++-- modelopt/torch/_compress/{ => tools}/common.py | 0 modelopt/torch/_compress/{ => tools}/hydra.py | 0 modelopt/torch/_compress/{ => tools}/runtime.py | 0 .../torch/_compress/nas/plugins/test_nas_convert.py | 2 +- .../torch/_compress/nas/plugins/test_nas_search.py | 2 +- tests/experimental/torch/_compress/test_compress.py | 2 +- 11 files changed, 16 insertions(+), 12 deletions(-) rename modelopt/torch/_compress/{decilm => tools}/checkpoint_utils.py (97%) rename modelopt/torch/_compress/{decilm => tools}/checkpoint_utils_hf.py (98%) rename modelopt/torch/_compress/{ => tools}/common.py (100%) rename modelopt/torch/_compress/{ => tools}/hydra.py (100%) rename modelopt/torch/_compress/{ => tools}/runtime.py (100%) diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 455cf3f8e..df953bb90 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py index 1b59ea800..4df9f009a 100644 --- a/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py +++ b/modelopt/torch/_compress/decilm/converters/convert_llama3_to_decilm.py @@ -23,10 +23,10 @@ from fire import Fire from transformers import LlamaConfig -from modelopt.torch._compress.decilm.checkpoint_utils import copy_tokenizer -from modelopt.torch._compress.decilm.checkpoint_utils_hf import copy_deci_lm_hf_code from modelopt.torch._compress.decilm.conversion_utils import convert_model_weights_to_decilm from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import 
DeciLMConfig +from modelopt.torch._compress.tools.checkpoint_utils import copy_tokenizer +from modelopt.torch._compress.tools.checkpoint_utils_hf import copy_deci_lm_hf_code """ example: diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index d821fbd02..2bb78661f 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -31,8 +31,8 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.mode import ( diff --git a/modelopt/torch/_compress/decilm/checkpoint_utils.py b/modelopt/torch/_compress/tools/checkpoint_utils.py similarity index 97% rename from modelopt/torch/_compress/decilm/checkpoint_utils.py rename to modelopt/torch/_compress/tools/checkpoint_utils.py index d35ec7975..4a05f82bb 100644 --- a/modelopt/torch/_compress/decilm/checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils.py @@ -26,8 +26,8 @@ from transformers import AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME -from modelopt.torch._compress.common import infer_weights_dtype -from modelopt.torch._compress.decilm.checkpoint_utils_hf import load_model_config +from modelopt.torch._compress.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch._compress.tools.common import infer_weights_dtype SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" diff --git a/modelopt/torch/_compress/decilm/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py similarity index 98% rename from modelopt/torch/_compress/decilm/checkpoint_utils_hf.py rename to modelopt/torch/_compress/tools/checkpoint_utils_hf.py index 3cfbf93c5..c686c1027 100644 --- a/modelopt/torch/_compress/decilm/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -68,7 +68,9 @@ def load_checkpoint( Unlike AutoModelForCausalLM.from_pretrained, the models loaded by this function use your local repo code, not the code inside the checkpoint. 
""" - from puzzle_tools.checkpoint_utils import load_state_dict # prevent circular import + from modelopt.torch._compress.tools.checkpoint_utils import ( + load_state_dict, # prevent circular import + ) if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) @@ -185,7 +187,9 @@ def _save_checkpoint( def split_checkpoint_to_subblocks(checkpoint_dir: Path | str) -> None: - from puzzle_tools.checkpoint_utils import load_state_dict # prevent circular import + from modelopt.torch._compress.tools.checkpoint_utils import ( + load_state_dict, # prevent circular import + ) if not isinstance(checkpoint_dir, Path): checkpoint_dir = Path(checkpoint_dir) diff --git a/modelopt/torch/_compress/common.py b/modelopt/torch/_compress/tools/common.py similarity index 100% rename from modelopt/torch/_compress/common.py rename to modelopt/torch/_compress/tools/common.py diff --git a/modelopt/torch/_compress/hydra.py b/modelopt/torch/_compress/tools/hydra.py similarity index 100% rename from modelopt/torch/_compress/hydra.py rename to modelopt/torch/_compress/tools/hydra.py diff --git a/modelopt/torch/_compress/runtime.py b/modelopt/torch/_compress/tools/runtime.py similarity index 100% rename from modelopt/torch/_compress/runtime.py rename to modelopt/torch/_compress/tools/runtime.py diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py index 7dc2d7228..47ff2531d 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -24,7 +24,7 @@ import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime # diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py index 04707d20f..df3c1e485 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py @@ -27,7 +27,7 @@ import modelopt.torch.nas as mtn from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime def test_nas_search(project_root_path: Path, tmp_path: Path): diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 3d5d6b666..80b7717bf 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -26,7 +26,7 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.runtime import NativeDdpRuntime +from modelopt.torch._compress.tools.runtime import NativeDdpRuntime # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) # using a one-click command. 
From 10ffdfefa8622d5700c3ca21d703f725f33868d0 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 09:14:56 +0100 Subject: [PATCH 17/49] refactoring Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/_compress/tools/checkpoint_utils.py b/modelopt/torch/_compress/tools/checkpoint_utils.py index 4a05f82bb..78ee0be00 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils.py @@ -51,7 +51,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( checkpoint_dir / SAFE_WEIGHTS_NAME ).exists(): - from utils.sharded_checkpoint_utils import ( + from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( load_sharded_state_dict, # local import to avoid circular import ) From 27a4456c3c71676877fe1e6f5eca62e7610602b8 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 09:42:49 +0100 Subject: [PATCH 18/49] add shared_checkpointing_utils Signed-off-by: Daniel Korzekwa --- .../tools/sharded_checkpoint_utils.py | 416 ++++++++++++++++++ 1 file changed, 416 insertions(+) create mode 100644 modelopt/torch/_compress/tools/sharded_checkpoint_utils.py diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py new file mode 100644 index 000000000..636f25157 --- /dev/null +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -0,0 +1,416 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# mypy: ignore-errors +import json +from collections.abc import Iterable, Mapping +from pathlib import Path +from typing import Literal, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from huggingface_hub import split_torch_state_dict_into_shards +from logger import mprint +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file +from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files +from typing_extensions import override +from utils.utils import EmptyInitOnDevice + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMDecoderLayer, + DeciLMForCausalLM, + rope_type_to_class, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict +from modelopt.torch._compress.tools.runtime import IRuntime + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, config: DeciLMConfig, block_index: int): + super().__init__() + self.config = config + self.block_index = block_index + + @override + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, None]: + if self.config.block_return_only_hidden_states: + return x + else: + return x, None + + +class DummyWTE(DummyModule): + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): + super().__init__() + self.n_embd = config.get_hidden_size() + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape # noqa: N806 + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: DeciLMConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape # noqa: N806 + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result + + +def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): + all_block_indexes = set(range(len(model.model.layers))) + has_first_block = 0 in owned_block_indexes + has_last_block = max(all_block_indexes) in owned_block_indexes + + unowned_block_indexes = all_block_indexes - owned_block_indexes + for block_index in unowned_block_indexes: + model.model.layers[block_index] = cast( + "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + ) + + if not has_first_block: + model.set_input_embeddings(DummyWTE(model.config)) + + if not has_last_block: + model.model.set_final_layer_norm(nn.Identity()) + if not (model.config.tie_word_embeddings and has_first_block): + model.set_output_embeddings(DummyLMHead(model.config)) + + return model + + +def create_dummy_model( + model_config: DeciLMConfig, + dtype: torch.dtype, +) -> 
DeciLMForCausalLM: + with torch.device("meta"): + model = DeciLMForCausalLM(model_config) + + rope_cls = rope_type_to_class[model_config.position_embedding_type] + model.model.rotary_emb = rope_cls(config=model.config) + + model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_final_layer_norm(nn.Identity()) + model.set_output_embeddings(DummyLMHead(model.config)) + + for block_index in range(model_config.get_num_hidden_layers()): + model.model.layers[block_index] = DummyBlock(model.config, block_index) + + return model + + +def load_and_shard_model( + runtime: IRuntime, + checkpoint_path: str | Path, + owned_block_indexes: set[int] | Literal["auto"] = "auto", + model_config: DeciLMConfig | None = None, + model_config_overrides: Mapping | None = None, +) -> DeciLMForCausalLM: + checkpoint_path = Path(checkpoint_path) + with runtime.device: + if model_config is None: + model_config = load_model_config( + checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True + ) + + if owned_block_indexes == "auto": + owned_block_indexes = set( + np.array_split(np.arange(model_config.get_num_hidden_layers()), runtime.world_size)[ + runtime.global_rank + ] + ) + + mprint("Initializing model shards") + model_shard = create_sharded_model( + runtime=runtime, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) + + if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( + checkpoint_path / SAFE_WEIGHTS_INDEX_NAME + ).exists(): + mprint("Loading shard state_dict from safetensors") + shard_keys = [ + *[name for name, _ in model_shard.named_parameters()], + *[name for name, _ in model_shard.named_buffers()], + ] + shard_state_dict = load_sharded_state_dict( + model_name_or_path=str(checkpoint_path), + keys_to_load=shard_keys, + device=runtime.device, + ) + + new_names = set(shard_state_dict.keys()) + mprint(f"{new_names=}") + model_shard.load_state_dict(shard_state_dict, assign=True) + + del shard_state_dict + + if model_config.tie_word_embeddings and (0 in owned_block_indexes): + # re-tie the weights in case the connection was severed + model_shard.tie_weights() + else: + mprint("Loading state_dict in main process") + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + + mprint("Distributing model to shards") + load_state_dict_to_shards( + runtime=runtime, model_shard=model_shard, loaded_state_dict=state_dict + ) + del state_dict + + model_shard.type(runtime.dtype) + + params_on_meta_device = [ + param_name + for param_name, param in model_shard.named_parameters() + if param.device == torch.device("meta") + ] + assert len(params_on_meta_device) == 0, ( + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + ) + + return model_shard + + +def create_sharded_model( + runtime: IRuntime, + model_config: DeciLMConfig, + owned_block_indexes: set[int], + device: str | torch.device | None = "meta", + dtype: torch.dtype | None = torch.float32, +): + if isinstance(device, str): + device = torch.device(device) + + runtime.wait_for_everyone() + + with EmptyInitOnDevice(device="meta", dtype=dtype): + model = DeciLMForCausalLM(model_config) + create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + + if device != torch.device("meta"): + local_shard_state_dict = { + k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() + } + + model.load_state_dict(local_shard_state_dict, assign=True) + + return model + + +def load_state_dict_to_shards( + runtime: 
IRuntime, model_shard: torch.nn.Module, loaded_state_dict: dict | None = None +) -> None: + from sewing_kit.utils import distributed_isend_obj, distributed_recv_obj + + model_shard.to("meta") + local_state_dict_keys = list(model_shard.state_dict().keys()) + + if runtime.is_main_process: + gathered_state_dict_keys = [None] * runtime.world_size + torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) + + assert loaded_state_dict is not None + loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()} + + works: list[torch.distributed.Work] = [] + for i, shard_keys in enumerate(gathered_state_dict_keys[1:]): + process_id = i + 1 + shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in shard_keys} + process_works = distributed_isend_obj(shard_state_dict, process_id) + works.extend(process_works) + + for work in works: + work.wait() + + shard_state_dict = { + k: v for k, v in loaded_state_dict.items() if k in local_state_dict_keys + } + else: + torch.distributed.gather_object(local_state_dict_keys) + shard_state_dict = distributed_recv_obj() + + print(f"{runtime.global_rank=} loaded state_dict shard") + + missing_keys, unexpected_keys = model_shard.load_state_dict( + shard_state_dict, strict=False, assign=True + ) + assert len(unexpected_keys) == 0 + assert all("dummy_param" in key for key in missing_keys) + + model_shard.to(runtime.device) + + runtime.wait_for_everyone() + + +def save_sharded_model( + runtime: IRuntime, + model_shard: torch.nn.Module | dict[str, torch.Tensor], + out_path: str | Path, +): + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ + runtime.wait_for_everyone() + + if isinstance(model_shard, torch.nn.Module): + shard_state_dict = model_shard.state_dict() + elif isinstance(model_shard, dict): + shard_state_dict = model_shard + else: + raise ValueError(f"Unrecognized model shard type: {type(model_shard)}") + + shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()} + total_shard_size = sum( + weight.numel() * weight.element_size() for weight in shard_state_dict.values() + ) + + num_shards = runtime.world_size + idx = runtime.global_rank + + out_path = Path(out_path) + shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") + + shard_metadata = { + "total_shard_size": total_shard_size, + "shard_keys": list(shard_state_dict.keys()), + "shard_file": str(shard_file), + } + + if runtime.is_main_process: + shard_metadatas = [{} for _ in range(runtime.world_size)] + torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) + total_size = sum(x["total_shard_size"] for x in shard_metadatas) + metadata = {"total_size": total_size} + weight_map: dict[str, str] = {} + for shard_metadata in shard_metadatas: + weight_map.update( + {k: Path(shard_metadata["shard_file"]).name for k in shard_metadata["shard_keys"]} + ) + + index = {"metadata": metadata, "weight_map": weight_map} + index_path = Path(str(out_path) + ".index.json") + index_path.write_text(json.dumps(index, indent=2)) + + else: + torch.distributed.gather_object(shard_metadata, dst=0) + + if out_path.suffix == ".safetensors": + safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, shard_file) + + runtime.wait_for_everyone() + + +def save_sharded_state_dict( + state_dict: dict[str, torch.Tensor], + save_directory: str | Path, + max_shard_size: str = "10GB", +) -> None: + save_directory = Path(save_directory) + 
save_directory.mkdir(exist_ok=True, parents=True) + state_dict = {k: v.cpu() for k, v in state_dict.items()} + + state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size) + + for shard_filename, param_names in tqdm( + state_dict_split.filename_to_tensors.items(), desc="saving sharded state dict" + ): + shard_path = save_directory / shard_filename + shard = {param_name: state_dict[param_name] for param_name in param_names} + safe_save_file(shard, shard_path, metadata={"format": "pt"}) + + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + index_path = save_directory / SAFE_WEIGHTS_INDEX_NAME + index_path.write_text(json.dumps(index, indent=2)) + + +def load_sharded_state_dict( + model_name_or_path: str | Path, + keys_to_load: Iterable[str] | None = None, + device: torch.device | str = "cpu", +) -> dict[str, torch.Tensor]: + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ + shard_paths = _resolve_shard_paths(model_name_or_path) + # print(f"shard_paths: {shard_paths}") + partial_state_dict = {} + for safetensors_path in shard_paths: + if keys_to_load is None: + shard = safe_load_file(safetensors_path) + partial_state_dict.update(shard) + else: + with safe_open(safetensors_path, framework="pt", device=str(device)) as f: + for key in f: + if key in keys_to_load: + partial_state_dict[key] = f.get_tensor(key) + return partial_state_dict + + +def _resolve_shard_paths(model_name_or_path: str) -> list[str]: + try: + unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME) + return [unsharded_path] + except OSError: + index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path) + return shard_paths + + +def is_in_safetensors_format(checkpoint_dir: Path) -> bool: + return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 + + +def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: + shard_paths = _resolve_shard_paths(model_name_or_path) + state_dict_shapes = {} + for safetensors_path in shard_paths: + with safe_open(safetensors_path, framework="pt") as f: + for key in f: + state_dict_shapes[key] = tuple(f.get_tensor(key).shape) + return state_dict_shapes From b0e22b7c603c19d418bdf097934bb1d1903d58b8 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 09:58:22 +0100 Subject: [PATCH 19/49] Add json tools Signed-off-by: Daniel Korzekwa --- .../_compress/tools/checkpoint_utils_hf.py | 9 +-- modelopt/torch/_compress/tools/robust_json.py | 72 +++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/_compress/tools/robust_json.py diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index c686c1027..1e2faf9dd 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -27,15 +27,16 @@ import torch from logger import mprint -from puzzle_tools import deci_lm_hf_code -from puzzle_tools.common import infer_weights_dtype -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from puzzle_tools.robust_json import json_dumps from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from 
utils.post_init_sparse import SparsityMethod +from modelopt.torch._compress.decilm.deci_lm_hf_code import deci_lm_hf_code +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.common import infer_weights_dtype + SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" RELATIVE_SUBBLOCKS_DIR = Path(SAFETENSORS_SUBBLOCKS_DIR_NAME) diff --git a/modelopt/torch/_compress/tools/robust_json.py b/modelopt/torch/_compress/tools/robust_json.py new file mode 100644 index 000000000..4c126b631 --- /dev/null +++ b/modelopt/torch/_compress/tools/robust_json.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +try: + from omegaconf import DictConfig, ListConfig, OmegaConf + + OMEGACONF_AVAILABLE = True +except ImportError: + OMEGACONF_AVAILABLE = False + + +class RobustJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if OMEGACONF_AVAILABLE and isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + return super().default(o) + + +def json_dumps(obj: Any) -> str: + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + path = Path(path) + text = path.read_text() + return json.loads(text) From 52e7827efaab2dda79ffddd296070e7373867af4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 10:13:13 +0100 Subject: [PATCH 20/49] add logger Signed-off-by: Daniel Korzekwa --- .../_compress/tools/checkpoint_utils_hf.py | 10 +- modelopt/torch/_compress/tools/logger.py | 172 ++++++++++++++++++ 2 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 modelopt/torch/_compress/tools/logger.py diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index 1e2faf9dd..97e592efc 100644 --- 
a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -26,16 +26,16 @@ from typing import Any, BinaryIO import torch -from logger import mprint +from puzzle_tools import deci_lm_hf_code +from puzzle_tools.common import infer_weights_dtype +from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from puzzle_tools.robust_json import json_dumps from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from utils.post_init_sparse import SparsityMethod -from modelopt.torch._compress.decilm.deci_lm_hf_code import deci_lm_hf_code -from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from modelopt.torch._compress.tools.common import infer_weights_dtype +from modelopt.torch._compress.tools.logger import mprint SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" diff --git a/modelopt/torch/_compress/tools/logger.py b/modelopt/torch/_compress/tools/logger.py new file mode 100644 index 000000000..0ec65fef5 --- /dev/null +++ b/modelopt/torch/_compress/tools/logger.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import inspect +import logging +import os +import sys + +import torch.distributed.launch # noqa: F401 + +logging.getLogger("fsspec.local").setLevel(logging.ERROR) +logging.getLogger("websockets.client").setLevel(logging.WARN) +logging.getLogger("websockets.server").setLevel(logging.WARN) +logging.getLogger("websockets.server:connection").setLevel(logging.WARN) + + +class LogColors: + BLUE = "\033[94m" + CYAN = "\033[96m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + RESET = "\033[0m" + + +class DistributedLogger(logging.Logger): + verbosity = logging.ERROR + + def __init__(self, name, level=logging.DEBUG): + super().__init__(name, level) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.global_rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + + def dist_log(self, msg: str, ranks: str = "main"): + """ + Log parameter msg with the given ranks. + parameter ranks: + "all": log with all ranks + "main": log with only rank 0 in node 0 + "last": log with only rank -1 in node 0 + "local_main": log with only rank 0 in all nodes + """ + # print(msg, ranks) + if ranks not in ["all", "main", "local_main", "last"]: + raise NotImplementedError( + f"Could not broadcast msg {msg} - " + f"ranks parameters choices are ['all', 'main', 'local_main']. 
Got {ranks}" + ) + # All ranks to print + if ranks == "all": + pass + + # Only main rank at node 0 to print + elif ( + (ranks == "main" and self.global_rank != 0) + or (ranks == "last" and self.local_rank != self.world_size - 1) + or (ranks == "local_main" and self.local_rank != 0) + ): + return + + message_source = self.get_caller_location() + + if self.global_rank == 0: + color = LogColors.GREEN + elif self.local_rank == self.world_size - 1: + color = LogColors.RED + else: + color = LogColors.CYAN + + self.info(f"{color}[rank-{self.global_rank}]{LogColors.RESET}[{message_source}]\t{msg}") + + # def dist_warning(self, msg): + # if self.verbosity <= logging.WARNING: + # self.warning(f"[rank-{self.global_rank}] " + msg) + + @staticmethod + def get_caller_location() -> str: + # Get the caller's stack frame + frame = inspect.currentframe() + + # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source + caller_frame = frame.f_back.f_back.f_back + + # Get the filename and line number from the caller's stack frame + filename = os.path.basename(caller_frame.f_code.co_filename) + lineno = caller_frame.f_lineno + return f"{filename}:{lineno}" + + +# Initialize logger +logging.setLoggerClass(DistributedLogger) +logger = logging.getLogger(__name__) +logger.propagate = False + +formatter = logging.Formatter("[%(asctime)s]%(message)s") +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(formatter) +handler.setLevel(logging.DEBUG) +logger.addHandler(handler) + +# Manually edit torch logger +torch_logger = logging.getLogger("torch") +torch_logger.handlers = logger.handlers +torch_logger.propagate = False + +# Manually edit deepspeed logger + +# Show some love to Mac & Windows users who can't easily install deepspeed ;) +# This is allowing running tests on Mac & Windows and train in non-DDP +try: + from deepspeed.utils import logger as deepspeed_logger + + deepspeed_logger.handlers = logger.handlers + deepspeed_logger.propagate = False +except ImportError: + # If deepspeed is not installed - no op + pass + +# Define a custom function to redirect warnings to logger +# def custom_warning_handler(message, category, filename, lineno, file=None, line=None): +# logger.dist_warning(f'{category.__name__}: {message} (in {filename}, line {lineno})') + + +# Use the custom warning handler +# warnings.showwarning = custom_warning_handler + +logger: DistributedLogger + + +def aprint(msg: str | None): + """ + All ranks from all nodes prints + """ + return logger.dist_log(msg=msg, ranks="all") + + +def lmprint(msg: str | None): + """ + All local main ranks prints (rank 0 in each node) + """ + return logger.dist_log(msg=msg, ranks="local_main") + + +def mprint(msg: str | None): + """ + Master prints only (rank 0 in node 0) + """ + return logger.dist_log(msg=msg, ranks="main") + + +def lprint(msg: str | None): + """ + Last rank prints only (rank -1 in node 0) + """ + return logger.dist_log(msg=msg, ranks="last") From f5c1c87b05b06ebfccefb2f1e65125be991b2c71 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 10:19:55 +0100 Subject: [PATCH 21/49] import refactoring Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/checkpoint_utils_hf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index 97e592efc..9238653f2 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ 
b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -26,16 +26,16 @@ from typing import Any, BinaryIO import torch -from puzzle_tools import deci_lm_hf_code -from puzzle_tools.common import infer_weights_dtype -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from puzzle_tools.robust_json import json_dumps from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from utils.post_init_sparse import SparsityMethod +from modelopt.torch._compress.decilm import deci_lm_hf_code +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.common import infer_weights_dtype from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dumps SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" From 0aa6320b57b582ad148d76e1d19437861ab353db Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 10:38:28 +0100 Subject: [PATCH 22/49] add post_init_sparse module Signed-off-by: Daniel Korzekwa --- .../_compress/tools/checkpoint_utils_hf.py | 2 +- .../torch/_compress/tools/post_init_sparse.py | 123 ++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/tools/post_init_sparse.py diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index 9238653f2..d0eb75d40 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -28,13 +28,13 @@ import torch from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from utils.post_init_sparse import SparsityMethod from modelopt.torch._compress.decilm import deci_lm_hf_code from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from modelopt.torch._compress.tools.common import infer_weights_dtype from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.post_init_sparse import SparsityMethod from modelopt.torch._compress.tools.robust_json import json_dumps SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" diff --git a/modelopt/torch/_compress/tools/post_init_sparse.py b/modelopt/torch/_compress/tools/post_init_sparse.py new file mode 100644 index 000000000..aeddbb78c --- /dev/null +++ b/modelopt/torch/_compress/tools/post_init_sparse.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import torch +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from torch import nn +from torch.nn.utils.prune import custom_from_mask + + +class SparsityMethod: + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + + @staticmethod + def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False): + sparsity_masks = {} + for name in list(state_dict.keys()): + original_name = name.replace("_orig", "") + mask_name = original_name + "_mask" + if name[-4:] == "orig" and mask_name in state_dict: + val = state_dict[name] + mask = state_dict[name[:-4] + "mask"] + val[mask == 0] = 0 + sparsity = (val == 0).sum() / mask.numel() + sparsity_masks[original_name[:-7]] = mask + if verbose: + print(f"fix_state_dict_inplace: {name} {sparsity=}") + del state_dict[mask_name] + del state_dict[name] + state_dict[original_name] = val + if change_dtype: + for name in state_dict: + state_dict[name] = state_dict[name].to(torch.bfloat16) + return state_dict, sparsity_masks + + def filter_function(self): + pass + + def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None: + for name, module in model.named_modules(): + if name in mask_dict: + custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device)) + print(name) + print(torch.sum(mask_dict[name]) / mask_dict[name].numel()) + + def do_sparsity(self, model: DeciLMForCausalLM, mask_dict=None): + full_name_layers = [] + for block_idx, block_config in enumerate(model.config.block_configs): + ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx] + att_name = block_config.attention.sparsify + block = model.model.layers[block_idx] + if hasattr(block, "mlp"): + for name, m in block.mlp.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "mlp." + name + ) + if hasattr(block, "self_attn"): + for name, m in block.self_attn.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "self_attn." 
+ name + ) + + if mask_dict is None: + state_dict_for_sparsifying = { + k.rstrip(".weight"): v + for k, v in model.state_dict().items() + if k.rstrip(".weight") in full_name_layers + } + mask_dict = self.calculate_masks(state_dict_for_sparsifying) + # print('Apply sparsity') + # print(full_name_layers) + # print(model.state_dict().keys()) + # print(list(mask_dict.keys())) + + self.apply_masks(model, mask_dict) + + +class SparsityMethod2o4(SparsityMethod): + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + mask_dict = {} + for key, val in state_dict.items(): + orig_size = val.shape + scores = val.flatten() ** 2 + mask = self.create_mask(scores) + mask = mask.reshape(orig_size) + mask_dict[key] = mask + return mask_dict + + def create_mask(self, score, value=0): + score = score # .cpu() + orig_size = score.shape + score = score.view(-1, 4) + mask = torch.zeros(score.shape) + values, indices = torch.topk(score, 2, dim=1) + rows = torch.arange(mask.size(0)).unsqueeze(-1) + mask[rows, indices] = 1 + mask = mask.view(orig_size) + return mask # dev = score.device, return mask.to(dev) + + @staticmethod + def filter_function(name, modules_to_sparsify_in_block): + if modules_to_sparsify_in_block is None: + return False + return name in modules_to_sparsify_in_block From 35d0dbccb70ed9163f2297639931f7871d7bc988 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 10:39:59 +0100 Subject: [PATCH 23/49] Add post_init_sparse Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/post_init_sparse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/_compress/tools/post_init_sparse.py b/modelopt/torch/_compress/tools/post_init_sparse.py index aeddbb78c..6e52ebe22 100644 --- a/modelopt/torch/_compress/tools/post_init_sparse.py +++ b/modelopt/torch/_compress/tools/post_init_sparse.py @@ -14,10 +14,11 @@ # limitations under the License. 
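For intuition, SparsityMethod2o4.create_mask above implements 2:4 semi-structured sparsity: weights are scored by squared magnitude, grouped in fours, and the two largest entries of each group are kept. A minimal standalone sketch of the same masking idea (illustrative only, not the class API):

    import torch

    weights = torch.tensor([0.1, -0.9, 0.4, 0.3])
    scores = (weights ** 2).view(-1, 4)       # one row per group of 4
    mask = torch.zeros_like(scores)
    _, top2 = torch.topk(scores, 2, dim=1)    # indices of the 2 largest per group
    mask[torch.arange(scores.size(0)).unsqueeze(-1), top2] = 1
    print(mask.view(-1))                      # tensor([0., 1., 1., 0.]) -> 2 of every 4 kept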
# mypy: ignore-errors import torch -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from torch import nn from torch.nn.utils.prune import custom_from_mask +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM + class SparsityMethod: def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: From e39a1ad23b2074a916a50f92f9a91e10c05e0001 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 10:51:21 +0100 Subject: [PATCH 24/49] merging hydra.py and hydra_utils.py Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- .../nas/plugins/compress_nas_plugin.py | 2 +- modelopt/torch/_compress/tools/hydra_utils.py | 81 +++++++++++++++++++ .../torch/_compress/compress_test_utils.py | 3 +- .../torch/_compress/test_compress.py | 2 +- 5 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/_compress/tools/hydra_utils.py diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index df953bb90..7d955c5ca 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 2bb78661f..27e76c584 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -31,7 +31,7 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField diff --git a/modelopt/torch/_compress/tools/hydra_utils.py b/modelopt/torch/_compress/tools/hydra_utils.py new file mode 100644 index 000000000..64c403565 --- /dev/null +++ b/modelopt/torch/_compress/tools/hydra_utils.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for hydra config initialization. 
+""" + +import datetime +import random +from pathlib import Path + +from hydra import compose, initialize, initialize_config_dir +from hydra.utils import get_object +from omegaconf import DictConfig, OmegaConf + + +def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: + """ + Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. + Used as a resolver in hydra configs. + """ + steps = (int(tokens) // int(block)) // int(mbs) + w = pct * steps + return max(1, round(w)) + + +def register_hydra_resolvers(): + OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) + OmegaConf.register_new_resolver( + "random_int", lambda low, high: random.randint(int(low), int(high)) + ) + OmegaConf.register_new_resolver( + "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None + ) + OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) + + +def initialize_hydra_config_for_dir( + config_dir: str, config_name: str, overrides: list[str] +) -> DictConfig: + """Initialize a hydra config from an absolute path for a config directory + + Args: + config_dir (str): + config_name (str): + overrides (List[str]): + + Returns: + DictConfig: + """ + + with initialize_config_dir(version_base=None, config_dir=config_dir): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args + + +def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: + with initialize(version_base=None, config_path=config_path): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index f0704f6c8..160098922 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -19,9 +19,10 @@ import torch from datasets import Dataset, DatasetDict -from puzzle_tools.hydra_utils import register_hydra_resolvers from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers + def setup_test_model_and_data( project_root_path: Path, diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 80b7717bf..8ce807ac2 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -46,7 +46,7 @@ # pip install mip # pip install lru-dict # -# export PYTHONPATH=$PYTHONPATH:/workspace/puzzletron/v1 +# export PYTHONPATH=$PYTHONPATH:.:/workspace/puzzletron/v1 # # pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts="" From 1bd0c67eafa64a3483f52fbd0735568448af12d6 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 12:42:56 +0100 Subject: [PATCH 25/49] Add integrationt test for attention pruning Signed-off-by: Daniel Korzekwa --- .../torch/_compress/compress_test_utils.py | 10 +- .../_compress/nas/plugins/test_nas_convert.py | 81 ++++++++++++- 
.../_compress/nas/plugins/test_nas_search.py | 8 +- .../configs/Llama-3_1-8B-attn-pruning.yaml | 108 ++++++++++++++++++ ...-8B.yaml => Llama-3_1-8B-ffn-pruning.yaml} | 0 .../torch/_compress/test_compress.py | 8 +- 6 files changed, 197 insertions(+), 18 deletions(-) create mode 100644 tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml rename tests/experimental/torch/_compress/resources/configs/{Llama-3_1-8B.yaml => Llama-3_1-8B-ffn-pruning.yaml} (100%) diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index 160098922..ce22e1864 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -33,8 +33,6 @@ def setup_test_model_and_data( Path, Path, Path, - Path, - str, ]: """ Setup the test model and data for the compress NAS search. @@ -46,8 +44,8 @@ def setup_test_model_and_data( runtime: the runtime to use for the test Returns: - tuple[Path, Path, Path, Path, str]: - the puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name + tuple[Path, Path, Path]: + the puzzle_dir, llama_checkpoint_path, dataset_path """ # Register Hydra custom resolvers (needed for config resolution) @@ -58,8 +56,6 @@ def setup_test_model_and_data( puzzle_dir = tmp_path llama_checkpoint_path = puzzle_dir / "input_model/llama" dataset_path = puzzle_dir / "dummy_dataset" - hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs" - hydra_config_name = "Llama-3_1-8B" if rank == 0: # Setup puzzle_dir and dataset @@ -77,8 +73,6 @@ def setup_test_model_and_data( puzzle_dir, llama_checkpoint_path, dataset_path, - hydra_config_dir, - hydra_config_name, ) diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py index 47ff2531d..cf284cfc8 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py @@ -31,24 +31,28 @@ # See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test # TODO: Remove those instructions once this test runs automatically on CI # -def test_nas_convert(project_root_path: Path, tmp_path: Path): +def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path): spawn_multiprocess_job( size=torch.cuda.device_count(), - job=partial(_test_nas_convert_multiprocess_job, project_root_path, tmp_path), + job=partial(_test_nas_convert_ffn_pruning_multiprocess_job, project_root_path, tmp_path), backend="nccl", ) -def _test_nas_convert_multiprocess_job( +def _test_nas_convert_ffn_pruning_multiprocess_job( project_root_path: Path, tmp_path: Path, rank: int, size: int ): with NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. 
- puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step @@ -86,4 +90,69 @@ def _test_nas_convert_multiprocess_job( runtime.wait_for_everyone() - print("PYTEST SUMMARY: test_nas_convert() test has finished successfully") + print("PYTEST SUMMARY: test_nas_convert_ffn_pruning() test has finished successfully") + + +def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_nas_convert_attn_pruning_multiprocess_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _test_nas_convert_attn_pruning_multiprocess_job( + project_root_path: Path, tmp_path: Path, rank: int, size: int +): + with NativeDdpRuntime( + dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) + ) as runtime: + # Setup the test model and data. + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime + ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-attn-pruning" + + # + # Run the mnt.convert() step + # + input_model = CompressModel() + mtn.convert( + input_model, + mode=[ + ( + "compress", + { + "puzzle_dir": str(puzzle_dir), + "input_model_path": str(llama_checkpoint_path), + "hydra_config_dir": str(hydra_config_dir), + "hydra_config_name": hydra_config_name, + "dataset_path": str(dataset_path), + }, + ) + ], + ) + + # + # Check assertions + # + if rank == 0: + # assertions for the score_pruning_activations step + rank = int(os.environ["RANK"]) + rank_filepath = ( + f"pruning/pruning_scores/attn_independent_kv_head_contribution/" + f"100samples_diverse_mini/rank_{rank}.pth" + ) + assert (puzzle_dir / rank_filepath).is_file() + + # assertions for the pruning_ckpts step + assert (puzzle_dir / "ckpts/n_heads_in_group8").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group16").exists() + assert (puzzle_dir / "ckpts/n_heads_in_group32").exists() + + runtime.wait_for_everyone() + + print("PYTEST SUMMARY: test_nas_convert_attn_pruning() test has finished successfully") diff --git a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py index df3c1e485..4a6a3ecce 100644 --- a/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py +++ b/tests/experimental/torch/_compress/nas/plugins/test_nas_search.py @@ -45,9 +45,13 @@ def _test_nas_search_multiprocess_job( dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. 
- puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # # Run the mnt.convert() step diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml new file mode 100644 index 000000000..21a3486f0 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-attn-pruning.yaml @@ -0,0 +1,108 @@ +defaults: + - pruning: attn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model 
solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml similarity index 100% rename from tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml rename to tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B-ffn-pruning.yaml diff --git a/tests/experimental/torch/_compress/test_compress.py b/tests/experimental/torch/_compress/test_compress.py index 8ce807ac2..46747674c 100644 --- a/tests/experimental/torch/_compress/test_compress.py +++ b/tests/experimental/torch/_compress/test_compress.py @@ -64,9 +64,13 @@ def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, ran dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) ) as runtime: # Setup the test model and data. - puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = ( - setup_test_model_and_data(project_root_path, tmp_path, rank, runtime) + puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, runtime ) + hydra_config_dir = ( + project_root_path / "tests/experimental/torch/_compress/resources/configs" + ) + hydra_config_name = "Llama-3_1-8B-ffn-pruning" # Convert the Llama model to DeciLM model. if rank == 0: From 0ecd52b24475cf04c80490e21f259de0d2584a77 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 12:56:26 +0100 Subject: [PATCH 26/49] add score_pruning_activations Signed-off-by: Daniel Korzekwa --- .../score_pruning_activations.py | 173 ++++++++++++++++++ .../nas/plugins/compress_nas_plugin.py | 2 +- 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/_compress/activation_scoring/score_pruning_activations.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py new file mode 100644 index 000000000..b574136bb --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
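+# Scores pruning activations by running model validation with activation hooks attached
+# (see launch_score_activations below). If a previous run already wrote rank_*.pth files
+# and args.json to cfg.pruning.activations_log_dir, the step is skipped, unless
+# cfg.pruning.force_restart_scoring is set.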
+ +from pathlib import Path + +import hydra +import torch +from logger import mprint +from omegaconf import DictConfig +from puzzle_tools.hydra_utils import register_hydra_resolvers +from puzzle_tools.runtime import BaseRuntime, NativeDDP_Runtime +from puzzle_tools.validate_model import validate_model +from utils.dist_utils import is_distributed +from utils.parsing import format_global_config + + +def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: + """ + Determine if the activation hook method has proper checkpoint support implemented. + + Args: + activation_hooks_kwargs: Hook configuration + + Returns: + bool: True if the hook method has save_state/load_state implemented + """ + method = activation_hooks_kwargs.get("method", "") + + # Methods with implemented checkpoint support + supported_methods = { + "iterative", # IterativeChannelContributionHook: save_state/load_state implemented + "independent", # IndependentChannelContributionHook: save_state/load_state implemented + "stats", # RouterStatsHook: save_state/load_state implemented + "ranked_choice_voting", # RankedChoiceVotingHook: save_state/load_state implemented + } + + return method in supported_methods + + +def check_scoring_completion( + activations_log_dir: str, runtime, activation_hooks_kwargs=None +) -> bool: + """ + Check if scoring is already completed by looking for the expected output files. + Also checks if the scoring method is safe for resume. + + Args: + activations_log_dir: Directory where activation logs should be stored + runtime: Runtime object for distributed processing + activation_hooks_kwargs: Hook configuration to check if resume is safe + + Returns: + bool: True if scoring is completed (has rank files and args.json) + """ + # Only check completion on main process (or if no distributed runtime) + if runtime is None or runtime.is_main_process: + log_dir = Path(activations_log_dir) + + # Check if directory exists + if not log_dir.exists(): + return False + + # Check for rank files (at least rank_0.pth should exist) + rank_files = list(log_dir.glob("rank_*.pth")) + + if not rank_files: + return False + + # Check for args.json (created by main process) + args_file = log_dir / "args.json" + has_args_json = args_file.exists() + + # Check for completion: if we have rank files and args.json, scoring is complete + if rank_files and has_args_json: + # Add optional completion info for debugging + mprint(f"Found completed scoring in {activations_log_dir}") + mprint(f" - Found {len(rank_files)} rank files") + mprint(f" - Found args.json: {has_args_json}") + + return True + + return False + + +def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool: + """ + Determine if we should skip scoring entirely (only if 100% complete). + Partial progress should proceed to validate_model for proper resume. 
+ + Args: + cfg: Configuration object + runtime: Runtime object for distributed processing + + Returns: + bool: True if we should skip scoring (100% completed), False if we should run/resume it + """ + # Check if activations_log_dir is specified + if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None: + mprint("No activations_log_dir specified, running scoring") + return False + + # Check for force restart flag + force_restart = getattr(cfg.pruning, "force_restart_scoring", False) + if force_restart: + mprint("Force restart flag set, will restart scoring regardless of existing artifacts") + return False + + # Get hook configuration to check if resume is mathematically safe + activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {}) + + # Check if scoring is already completed + is_completed = check_scoring_completion( + cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs + ) + + # Broadcast the result to all processes in distributed mode + if runtime is not None and runtime.world_size > 1: + should_skip = [is_completed] # Use list for mutable object + torch.distributed.broadcast_object_list(should_skip, src=0) + is_completed = should_skip[0] + + if is_completed: + mprint("Scoring 100% completed, skipping...") + + return is_completed + + +# Old progress tracking removed - checkpoint manager handles all progress tracking + + +def launch_score_activations(cfg: DictConfig, runtime): + # Check if we should skip scoring entirely (only if 100% complete) + if should_skip_scoring_completely(cfg, runtime): + return + + mprint("Starting pruning activation scoring...") + + # The checkpoint manager inside validate_model handles all progress tracking + validate_model(args=cfg.pruning, runtime=runtime) + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + cfg = hydra.utils.instantiate(cfg) + mprint(format_global_config(cfg, title="Score Pruning Activations")) + + _runtime = ( + NativeDDP_Runtime( + dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") + ) + if is_distributed() + else BaseRuntime(dtype=torch.bfloat16) + ) + with _runtime as runtime: + launch_score_activations(cfg, runtime) + runtime.wait_for_everyone() + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 27e76c584..146274114 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -23,11 +23,11 @@ import build_library_and_stats import mip_and_realize_models import pruning_ckpts -import score_pruning_activations import scoring import torch from torch import nn +from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) From 278c6b735962ae0cafcd9227011df72e009106f9 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 13:05:40 +0100 Subject: [PATCH 27/49] import refactoring Signed-off-by: Daniel Korzekwa --- .../activation_scoring/score_pruning_activations.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index b574136bb..99acc9ad5 100644 --- 
a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -17,14 +17,15 @@ import hydra import torch -from logger import mprint from omegaconf import DictConfig -from puzzle_tools.hydra_utils import register_hydra_resolvers -from puzzle_tools.runtime import BaseRuntime, NativeDDP_Runtime from puzzle_tools.validate_model import validate_model from utils.dist_utils import is_distributed from utils.parsing import format_global_config +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime + def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: """ @@ -157,7 +158,7 @@ def main(cfg: DictConfig) -> None: mprint(format_global_config(cfg, title="Score Pruning Activations")) _runtime = ( - NativeDDP_Runtime( + NativeDdpRuntime( dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes") ) if is_distributed() From 7a0af16aa00e56897094b5317b02e729acba726d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 13:15:15 +0100 Subject: [PATCH 28/49] add dist_utils Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/utils/dist_utils.py | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 modelopt/torch/_compress/utils/dist_utils.py diff --git a/modelopt/torch/_compress/utils/dist_utils.py b/modelopt/torch/_compress/utils/dist_utils.py new file mode 100644 index 000000000..84f8f2bab --- /dev/null +++ b/modelopt/torch/_compress/utils/dist_utils.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import torch.distributed as dist + + +def is_distributed(): + """ + From torchtune.utils.is_distributed() : https://docs.pytorch.org/torchtune/0.2/generated/torchtune.utils.is_distributed.html + """ + port = os.environ.get("MASTER_PORT", "") + addr = os.environ.get("MASTER_ADDR", "") + size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", -1)) + avlb = dist.is_available() + return bool(port and addr and size > 1 and rank >= 0 and avlb) From 0f0cbbdc76e1d7eb7a63b02bc18ad458844aabad Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 19:59:03 +0100 Subject: [PATCH 29/49] Add validate_model Signed-off-by: Daniel Korzekwa --- .../score_pruning_activations.py | 4 +- .../tools/sharded_checkpoint_utils.py | 8 +- .../torch/_compress/tools/validate_model.py | 416 ++++++++ modelopt/torch/_compress/utils/utils.py | 936 ++++++++++++++++++ 4 files changed, 1358 insertions(+), 6 deletions(-) create mode 100644 modelopt/torch/_compress/tools/validate_model.py create mode 100644 modelopt/torch/_compress/utils/utils.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index 99acc9ad5..3617bdb1c 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -18,13 +18,13 @@ import hydra import torch from omegaconf import DictConfig -from puzzle_tools.validate_model import validate_model -from utils.dist_utils import is_distributed from utils.parsing import format_global_config from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.validate_model import validate_model +from modelopt.torch._compress.utils.dist_utils import is_distributed def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool: diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 636f25157..08ee7e9d4 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -24,6 +24,7 @@ import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards from logger import mprint +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file @@ -31,16 +32,15 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override -from utils.utils import EmptyInitOnDevice from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, - DeciLMForCausalLM, rope_type_to_class, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.utils.utils import EmptyInitOnDevice class DummyModule(nn.Module): @@ -386,7 +386,7 @@ def load_sharded_state_dict( 
partial_state_dict.update(shard) else: with safe_open(safetensors_path, framework="pt", device=str(device)) as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable if key in keys_to_load: partial_state_dict[key] = f.get_tensor(key) return partial_state_dict @@ -411,6 +411,6 @@ def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: state_dict_shapes = {} for safetensors_path in shard_paths: with safe_open(safetensors_path, framework="pt") as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable state_dict_shapes[key] = tuple(f.get_tensor(key).shape) return state_dict_shapes diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py new file mode 100644 index 000000000..b852299db --- /dev/null +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -0,0 +1,416 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import textwrap +from copy import deepcopy +from pathlib import Path + +import torch.distributed +from omegaconf import DictConfig +from torch import nn +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from utils.activation_hooks.utils import register_activation_hooks +from utils.data.dataloaders import create_validation_dataloader +from utils.parsing import simple_parse_args_string +from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline +from utils.validation import calculate_losses + +from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model + +# #TODO:Import slack from root utils directory +# root_path = os.path.join(os.path.dirname(__file__), "..", "..") +# if root_path not in sys.path: +# sys.path.append(root_path) +# from utils.slack import send_slack_message + +""" +Two goals: +1) Calculate lm loss and token accuracy for a model. +May raise lots of NCCL warnings when it finishes, don't be alarmed. +Can be used to validate a lit-llama model or a HuggingFace model. +If HuggingFace, automatically uses pipeline parallelism via device_map="auto". +If lit-llama, will use pipeline parallelism if called with --pipeline_parallel and run using torchrun. + +2) Register hooks to capture the inputs and the outputs of pytorch modules. +For example, to collect activations scores for various layers (ffn, layer_norm, etc.) +that are used for pruning (ffn_hidden_size, embedding_pruning, etc). 
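+For intuition, the activation capture in goal (2) boils down to standard PyTorch forward
+hooks; a minimal, hypothetical sketch (the real hooks come from utils.activation_hooks
+and are configured via --activation_hooks_kwargs):
+
+    import torch
+    from torch import nn
+
+    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
+    captured = {}
+    def make_hook(name):
+        def hook(module, inputs, output):
+            captured[name] = output.detach()
+        return hook
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            module.register_forward_hook(make_hook(name))
+    model(torch.randn(2, 8))  # captured now holds the output of every nn.Linear
+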
+See --activations_log_dir and --activation_hooks_kwargs args arguments. + +Usage: +====== + +########################################################### +### For lit-llama multi gpu: +### Use torchrun and the flag --pipeline_parallel. +### Example: + +MODEL="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/meta-llama/Llama-3.1-8B-Instruct" +DATASET="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/datasets/diverse_mix/releases/v0.4_mini" + +NUM_GPUS=$(python -c "import torch; print(torch.cuda.device_count())") +torchrun --rdzv-backend=static --master-addr 127.0.0.1 --master-port 8754 --nproc-per-node=${NUM_GPUS} -m \ + scripts.validate_model --pipeline_parallel \ + --model_name_or_path=${MODEL} \ + --dataset_path ${DATASET} \ + --block_size 1024 --eval_samples 32 --seed 42 --shuffle_seed 444 --bos_rate 0.5 --data_column conversation \ + --val_dataset_name=__auto__ --micro_batch_size 1 \ + 2>&1 | tee -a "${MODEL}/validate_model_outputs.txt" + + + +########################################################### +### For lit-llama multi gpu with teacher similarity scores: +### Use torchrun and the flag --pipeline_parallel. +### Specify --teacher_dir. +### Example: + +TEACHER="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/meta-llama/Meta-Llama-3-8B-Instruct" +MODEL="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/meta-llama/Llama-3.1-8B-Instruct" +DATASET="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/datasets/diverse_mix/releases/v0.4_mini" + +NUM_GPUS=$(python -c "import torch; print(torch.cuda.device_count())") +torchrun --rdzv-backend=static --master-addr 127.0.0.1 --master-port 8754 --nproc-per-node=${NUM_GPUS} -m \ + scripts.validate_model --pipeline_parallel \ + --model_name_or_path=${MODEL} \ + --teacher_dir=${TEACHER} \ + --dataset_path ${DATASET} \ + --block_size 8192 --eval_samples 32 --seed 42 --shuffle_seed 444 --bos_rate 0.5 --data_column conversation \ + --val_dataset_name=__auto__ --micro_batch_size 1 + + +########################################################### +### For huggingface models (device_map="auto") or lit-llama single gpu: +### Use python (not torchrun) and do not use the flag --pipeline_parallel. 
+python -m scripts.validate_model \ + --all --the --other --args +### Example: + +MODEL="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/meta-llama/Llama-3.1-8B-Instruct-HF" +DATASET="/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/datasets/diverse_mix/releases/v0.4_mini" + +python -m \ + scripts.validate_model \ + --model_name_or_path=${MODEL} \ + --dataset_path ${DATASET} \ + --block_size 1024 --eval_samples 32 --seed 42 --shuffle_seed 444 --bos_rate 0.5 --data_column conversation \ + --val_dataset_name=__auto__ --micro_batch_size 1 \ + 2>&1 | tee -a "${MODEL}/validate_model_outputs.txt" + + + +########################################################### +### Calculate activations log (channel contribution) for lit-llama multi gpu: +NUM_GPUS=$(python -c "import torch; print(torch.cuda.device_count())") +torchrun --rdzv-backend=static --master-addr 127.0.0.1 --master-port 8754 --nproc-per-node=${NUM_GPUS} \ + -m scripts.validate_model --pipeline_parallel \ + --model_name_or_path $MODEL \ + --dataset_path $DATASET \ + --block_size 8192 --eval_samples 4096 --seed 42 --bos_rate 0.5 --data_column conversation \ + --val_dataset_name=train --shuffle_seed 81436 --micro_batch_size 4 \ + --activations_log_dir activations_log_${FILESAFE_MODEL_NAME} + + +""" + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + help="Required unless a model is passed to the function", + ) + parser.add_argument("--dataset_path", type=str, required=True) + + parser.add_argument( + "--teacher_dir", + type=str, + default=None, + help="If given, calculates teacher similarity scores (kl_div etc.) " + "Only works with lit-llama models.", + ) + parser.add_argument("--output_dir_name", type=str, default="validation") + parser.add_argument( + "--calculate_full_score_ablations", + action="store_true", + help="Calculates a diverse suite of teacher similarity scores. " + "By default only a small suite is calculated, which is good for most use-cases.", + ) + + parser.add_argument("--tokenizer_name", type=str, default=None) + parser.add_argument("--data_column", type=str, default="content") + parser.add_argument("--fim_rate", type=float, default=0) + parser.add_argument("--fim_spm_rate", type=float, default=0) + parser.add_argument("--eval_samples", type=int, default=None) + parser.add_argument("--block_size", type=int, default=4096) + parser.add_argument("--micro_batch_size", type=int, default=4) + parser.add_argument("--val_dataset_name", type=str, default="__auto__") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--source_datasets_to_discard", nargs="+", type=str) + parser.add_argument("--bos_rate", type=float, default=1.0) + parser.add_argument("--shuffle_seed", type=int, default=None) + parser.add_argument("--varlen", action="store_true") + parser.add_argument("--pipeline_parallel", action="store_true") + parser.add_argument("--write_results", action="store_true") + parser.add_argument("--activations_log_dir", type=str, default=None) + parser.add_argument( + "--activation_hooks_kwargs", + type=str, + default=None, + help="Comma separated string arguments, e.g. `arg1=val1,arg2=val2`", + ) + parser.add_argument( + "--calc_losses_on_cpu", + action="store_true", + help="Very slow, not recommended. 
Can help avoid OOM.", + ) + return parser + + +def parse_args() -> argparse.Namespace: + parser = build_arg_parser() + args, unknown_args = parser.parse_known_args() + return args + + +@torch.no_grad() +def validate_model( + args: argparse.Namespace | DictConfig, + model: PreTrainedModel | None = None, + tokenizer: PreTrainedTokenizerBase | None = None, + target_hidden_states_per_batch: list[torch.Tensor] | None = None, + return_hidden_states: bool = False, + runtime: IRuntime | None = None, + calculate_full_score_ablations: bool = False, + val_dataloader: DataLoader | None = None, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + if val_dataloader is None: + val_dataloader = ( + prepare_dataloader(args, tokenizer) + if (runtime is None or runtime.is_main_process) + else None + ) + validation_full_iters = ( + args.eval_samples // args.micro_batch_size + ) # model pipeline, single data rank + + model = prepare_model(args, model, runtime) + + just_model_forward = False + checkpoint_manager = None + activation_hooks = None + + if args.activations_log_dir is not None: + activation_hooks_kwargs = ( + simple_parse_args_string(args.activation_hooks_kwargs) + if isinstance(args.activation_hooks_kwargs, str) + else args.activation_hooks_kwargs + ) + activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + + # Create activation hooks first + activation_hooks, hook_class = register_activation_hooks( + model=model, activation_hooks_kwargs=activation_hooks_kwargs + ) + + # Create checkpoint manager with hooks + from utils.checkpoint_manager import ScoringCheckpointManager + + mprint( + f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" + ) + checkpoint_manager = ScoringCheckpointManager( + checkpoint_dir=args.activations_log_dir, + runtime=runtime, + activation_hooks=activation_hooks, + checkpoint_interval=50, # Save every 50 batches + ) + + # Load existing checkpoint if available + mprint("Attempting to load existing checkpoint...") + checkpoint_data = checkpoint_manager.load_checkpoint() + if checkpoint_data: + mprint(f"Checkpoint loaded successfully: {checkpoint_data}") + else: + mprint("No checkpoint found, starting fresh") + just_model_forward = True + model.lm_head = nn.Identity() + + if runtime is None: + losses, hidden_states_per_batch = calculate_losses( + model=model, + dataloader=val_dataloader, + checkpoint_manager=checkpoint_manager, + ) + else: + losses, hidden_states_per_batch = calculate_losses_pipeline( + runtime=runtime, + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + ) + + if losses is not None: + avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} + + results_str = f""" + validate_model: + {args.model_name_or_path=} + Average losses = {avg_losses} + Actual num samples = {len(next(iter(losses.values()))["per_sample"])} + {args=} + """ + results_str = textwrap.dedent(results_str) + aprint(results_str) + if args.write_results: + Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str) + # TODO: send_slack_message(results_str) + + if args.activations_log_dir is not None: + hook_class.dump_activations_logs(activation_hooks, 
args.activations_log_dir, args, runtime) + + return losses, hidden_states_per_batch + + +def prepare_model( + args: argparse.Namespace, + model: PreTrainedModel | None = None, + runtime: IRuntime | None = None, +) -> nn.Module: + if model is None: + assert args.model_name_or_path is not None + if runtime is not None: + model = load_and_shard_model( + runtime, + args.model_name_or_path, + model_config_overrides={"block_size": args.block_size}, + ) + else: + try: + model = load_checkpoint( + args.model_name_or_path, + model_config_overrides={"block_size": args.block_size}, + ignore_unexpected_config_keys=True, + ) + model.to("cuda") + except FileNotFoundError: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + torch_dtype="auto", + device_map="auto", + trust_remote_code=True, + ) + + model.eval() + return model + + +def prepare_dataloader( + args: argparse.Namespace, + tokenizer: PreTrainedTokenizerBase | None = None, +) -> DataLoader: + if tokenizer is None: + tokenizer_name = getattr(args, "tokenizer_name", None) + assert (tokenizer_name is not None) or (args.model_name_or_path is not None) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name or args.model_name_or_path, trust_remote_code=True + ) + + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=args.seed, + tokenizer=tokenizer, + block_size=args.block_size, + dataset=args.dataset_path, + content_field=args.data_column, + fim_rate=args.fim_rate, + fim_spm_rate=args.fim_spm_rate, + micro_batch_size=args.micro_batch_size, + eval_samples=args.eval_samples, + dataset_name=args.val_dataset_name, + source_datasets_to_discard=args.source_datasets_to_discard, + bos_rate=args.bos_rate, + varlen=args.varlen, + shuffle_seed=args.shuffle_seed, + load_dataset_fn=args.load_dataset_fn, + ) + + return val_dataloader + + +def validate_model_with_teacher_similarity_scores( + args: argparse.Namespace, + runtime: IRuntime, +): + from puzzle_tools.validation_utils import ( + validate_model_and_extract_hidden_states, + validate_model_with_teacher_similarity_metrics, # importing here to avoid cyclic import + ) + + output_dir = Path(args.model_name_or_path) / args.output_dir_name + + teacher_val_args = deepcopy(args) + teacher_val_args.model_name_or_path = args.teacher_dir + teacher_hidden_states = validate_model_and_extract_hidden_states( + args=teacher_val_args, + model=None, + tokenizer=None, + output_dir=output_dir, + model_name="teacher", + runtime=runtime, + ) + + validate_model_with_teacher_similarity_metrics( + args=args, + model=None, + tokenizer=None, + target_hidden_states_per_batch=teacher_hidden_states, + output_dir=output_dir, + model_name="this_model", + runtime=runtime, + calculate_full_score_ablations=args.calculate_full_score_ablations, + ) + + +def main(): + args = parse_args() + if args.pipeline_parallel: + with NativeDdpRuntime(dtype=torch.bfloat16) as runtime: + if args.teacher_dir is None: + validate_model(args=args, runtime=runtime) + else: + validate_model_with_teacher_similarity_scores(args=args, runtime=runtime) + else: + validate_model(args=args, runtime=None) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py new file mode 100644 index 000000000..197e47068 --- /dev/null +++ b/modelopt/torch/_compress/utils/utils.py @@ -0,0 +1,936 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# mypy: ignore-errors + +import dataclasses +import functools +import json +import os +import pathlib +import pickle # nosec B403 - pickle used for PyTorch model serialization +import re +import warnings +from copy import deepcopy +from io import BytesIO +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import torch +from fire import Fire +from puzzle_tools.deci_lm_hf_code.block_config import AttentionConfig, BlockConfig, FFNConfig +from tqdm import tqdm + + +def calculate_kv_dim(n_heads_in_group: int, n_head: int, n_embd: int) -> int: + if n_heads_in_group is None: + return 0 + n_kv_heads = n_head // n_heads_in_group + head_size = n_embd // n_head + kv_dim = 2 * n_kv_heads * head_size + return kv_dim + + +def raise_unknown_subblock_config_error(subblock_config: Any) -> None: + raise ValueError( + f"subblock_config should be an instance of FFNConfig or AttentionConfig, instead got {type(subblock_config)}" + ) + + +def sizeof_dtype(dtype: torch.dtype | str) -> int | float: + """returns the size in bytes of the given dtype""" + if dtype == "nvfp4": + return 1 / 1.7 + return torch.tensor([], dtype=dtype).element_size() + + +def _sort_by_prefix(strings, prefix_order): + """Sorts a list of strings based on a given order of their prefix.""" + prefix_index = {prefix: i for i, prefix in enumerate(prefix_order)} + sorted_strings = sorted( + strings, key=lambda s: (prefix_index.get(s.split(".")[0], len(prefix_order)), s) + ) + return sorted_strings + + +def load_puzzle_solutions( + puzzle_dir: str | pathlib.Path, + include_patterns: list[str] = (), + exclude_patterns: list[str] = (), + drop_empty_columns: bool = True, + default_generation_seq_len: int | None = None, +) -> pd.DataFrame: + """Loads all puzzle solutions from a puzzle directory into a DataFrame. + + Supports `include_patterns` and `exclude_patterns` which are lists of expressions that will be + evaluated with `re.findall()`. Can be used together. + """ + puzzle_dir = pathlib.Path(puzzle_dir) + + # Find all matching solutions files. 
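+    # Both pattern lists are matched with re.findall() against the full path of each
+    # solutions.json: a file is dropped if it matches any exclude pattern and, when
+    # include_patterns is given, it must also match at least one include pattern.
+    # For example (illustrative paths): include_patterns=[r"llama"] together with
+    # exclude_patterns=[r"debug"] keeps .../llama_run/solutions.json but drops
+    # .../llama_debug/solutions.json.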
+ solution_files = list(puzzle_dir.glob("**/solutions.json")) + if exclude_patterns: + solution_files = [ + f + for f in solution_files + if not any(re.findall(pat, str(f)) for pat in exclude_patterns) + ] + if include_patterns: + solution_files = [ + f for f in solution_files if any(re.findall(pat, str(f)) for pat in include_patterns) + ] + + solution_records = [] + for solution_file in solution_files: + with open(solution_file) as f: + solutions = json.load(f) + + for sol_idx, sol in enumerate(solutions): + mip_constraints = ( + sol["puzzle_args"].get("mip_constraints") or sol["puzzle_args"]["constraints"] + ) # For backward compatibility with old puzzle formats. + human_constraints = sol["puzzle_args"].get("human_constraints") + args = ( + sol.get("subblock_stats") or sol["puzzle_args"]["subblock_stats_args"] + ) # For partial backward compatibility - if its an old run without "subblock_stats", + # we at least take the args used to filter the subblock_stats. + total_costs = sol["total_costs"] + total_value = sol["total_value"] + chosen_items = sol["chosen_items"] + + if "generation_seq_len" not in args and default_generation_seq_len is None: + raise ValueError( + "Trying to parse an old puzzle dir (without the full `subblock_stats` in " + "solution.json) without explicitly providing `default_generation_seq_len`. " + "For clarity of the calculated results we don't assume this is 1000 by default. " + "Please provide the argument." + ) + + gen_seq_len = args.get("generation_seq_len", default_generation_seq_len) + calculated_cost_throughput = ( + gen_seq_len * args["batch_size"] / (total_costs["stats.runtime_ms"] / 1000) + if "batch_size" in args and "stats.runtime_ms" in total_costs + else np.nan + ) + total_costs["throughput"] = calculated_cost_throughput + + calculated_constraint_throughput = ( + gen_seq_len * args["batch_size"] / (mip_constraints["stats.runtime_ms"] / 1000) + if "batch_size" in args and "stats.runtime_ms" in mip_constraints + else np.nan + ) + + record = ( + { + f"mip_constraint.{k.removeprefix('stats.')}": v + for k, v in mip_constraints.items() + } + | {f"human_constraint.{k}": v for k, v in human_constraints.items()} + | { + f"args.{k}": v + for k, v in args.items() + if k in ("batch_size", "gpu", "prefill_seq_len", "generation_seq_len") + } + | {f"costs.{k.removeprefix('stats.')}": v for k, v in total_costs.items()} + | total_value # keys in `total_value` already includes the `total_value.` prefix + | { + "calculated_constraint_throughput": calculated_constraint_throughput, + "solution_idx": sol_idx, + "solution_file": str(solution_file), + "chosen_items": chosen_items, + } + ) + solution_records.append(record) + + if not solution_records: + raise ValueError( + f"No solutions were found in {puzzle_dir} ({include_patterns=}, {exclude_patterns=})" + ) + + df = pd.DataFrame.from_dict(solution_records) + df = df.reindex( + _sort_by_prefix( + df.columns, ["human_constraint", "mip_constraint", "args", "costs", "metrics"] + ), + axis="columns", + ) + + if drop_empty_columns: + df = df.dropna(axis="columns", how="all") + + return df + + +def load_json(file_path: str): + if not os.path.exists(file_path): + print("file does not exist {file_path}") + return None + + with open(file=file_path) as f: + return json.load(f) + + +def save_json(obj: object, file_path: str): + with open(file=file_path, mode="w") as f: + return json.dump(obj, f) + + +def print_solution(solution_path: str, solution_id=0): + solution = load_json(solution_path) + if solution is not None: + sol = solution 
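+        # A solutions file may hold a list of solutions, a single solution dict, or a
+        # dict nested under "puzzle_solution"; normalize to one solution below.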
+ if isinstance(solution, list): + sol = solution[solution_id] + elif isinstance(solution, dict) and solution.get("puzzle_solution") is not None: + sol = solution.get("puzzle_solution") + print(sol["solution_repr"]) + + sol["total_costs"]["stats.num_params"] = ( + f"{sol['total_costs']['stats.num_params'] / 1e9:.2f}B" + ) + + print("costs are: ", sol["total_costs"]) + print("sum kl_div is : ", sol["total_value"]) + sum_kl_div = sol["total_value"]["metrics.kl_div_loss"] + # actual kl_div + validation_path = solution_path.replace( + ".json", f"--validation/solution_{solution_id}.json" + ) + if os.path.exists(validation_path): + validation = load_json(validation_path) + kl_div_loss = validation["kl_div_loss"]["avg"] + lm_loss = validation["lm_loss"]["avg"] + print(f"actual {kl_div_loss=}") + print(f"actual {lm_loss=}") + print(f"{sum_kl_div:.3f}, {kl_div_loss:.3f}, {lm_loss:.3f}") + + +def get_block_repr(parent_layer_indices, single_sequence_replacement): + block_variant_name = "" + if isinstance(single_sequence_replacement, list): + for block_config in single_sequence_replacement: + block_variant_name += block_config_to_str(deepcopy(block_config)) + + else: + block_variant_name = block_config_to_str(deepcopy(single_sequence_replacement)) + return f"block(s) {parent_layer_indices}: " + block_variant_name + + +def load_scores(validation_dir: str) -> pd.DataFrame: + rows = [] + for solution_path in tqdm(list(pathlib.Path(validation_dir).glob("solution*.json"))): + solution_info = json.loads(solution_path.read_text()) + solution_id = re.search(r"solution_(\d+)", solution_path.stem).group(1) + # return solution_info + # print(solution_path) + # print(solution_info["puzzle_solution"]["single_sequence_replacement"].keys()) + replacement_info = solution_info["puzzle_solution"]["single_sequence_replacement"] + parent_layer_indices = replacement_info["parent_layer_indices"] + # kl_div_loss = solution_info['kl_div_loss'] + scores = { + k: v["avg"] for k, v in solution_info.items() if isinstance(v, dict) and v.get("avg") + } + + child_block_configs = deepcopy(replacement_info["child_block_configs"].copy()) + if isinstance(child_block_configs, list): + block_variant_name = "" + for block_config in child_block_configs: + block_variant_name += block_config_to_str(deepcopy(block_config)) + else: + block_variant_name = block_config_to_str(deepcopy(block_config)) + rows.append( + { + "solution_id": solution_id, + "parent_layer_indices": parent_layer_indices, + "block_variant_name": block_variant_name, + "block_repr": get_block_repr(parent_layer_indices, block_config), + "block_config": replacement_info["child_block_configs"], + **scores, + } + ) + + return pd.DataFrame(rows) + + +def load_val_results(validation_dir: str) -> pd.DataFrame: + rows = [] + validation_dir = pathlib.Path(validation_dir) + sol_paths = list(validation_dir.rglob("solution_0.json")) + teacher_paths = list(validation_dir.rglob("teacher.json")) + for solution_path in tqdm(sol_paths + teacher_paths): + solution_info = json.loads(solution_path.read_text()) + solution_id = ( + str(solution_path.parent.parent.relative_to(validation_dir)) + if solution_path.name != "teacher.json" + else "teacher" + ) + scores = { + k: v["avg"] for k, v in solution_info.items() if isinstance(v, dict) and v.get("avg") + } + rows.append({"solution_id": solution_id, **scores}) + df = pd.DataFrame(rows) + df = df.drop_duplicates() + return df + + +def validate_scores_with_solutions(validation_dir: str, solutions_path: str) -> pd.DataFrame: + scores_df = 
load_scores(validation_dir=validation_dir) + solutions_list = pd.read_json(solutions_path) + + def add_sol_num(idx: int, block_repr): + return f"{idx} {block_repr}" + + solutions_repr = [ + add_sol_num( + idx, + get_block_repr( + s["single_block_replacement"]["block_idx"], + s["single_block_replacement"]["block_config"], + ), + ) + for idx, s in enumerate(solutions_list) + ] + scores_df["sol_repr"] = scores_df["solution_id"] + " " + scores_df["block_repr"] + + assert len(np.setdiff(scores_df.sol_repr, solutions_repr)) == 0 + + +def delete_scores(validation_dir: str, blocks_regex: str): + scores_df = load_scores(validation_dir=validation_dir) + + score_ids_to_delete = scores_df.query(f'block_repr.str.contains("{blocks_regex}")').solution_id + for score in score_ids_to_delete: + scores_path = Path(validation_dir) / f"solution_{score}.json" + print(f"about to delete: {scores_path}") + os.remove(scores_path) + + +def solution_to_str(block_configs: list[dict[str, Any] | BlockConfig]) -> str: + block_configs = deepcopy(block_configs) + reps = [] + for block_idx, block_config in enumerate(block_configs): + rep = f"block_{block_idx}:".ljust(9) + rep += block_config_to_str(block_config) + reps.append(rep) + rep = "\n".join(reps) + "\n" + return rep + + +def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: + if block_config is None: + return None + rep = "" + if dataclasses.is_dataclass(block_config): + block_config = dataclasses.asdict(block_config) + for subblock_name in ["attention", "ffn"]: + subblock_config = block_config[subblock_name] + rep += subblock_config_to_str(subblock_config, subblock_name) + return rep + + +def subblock_config_to_str( + subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None, + subblock_name: None | str = None, +) -> str | None: + if subblock_config is None: + return None + subblock_name = ( + "ffn" + if isinstance(subblock_config, FFNConfig) + else "mamba" + if isinstance(subblock_config, AttentionConfig) and subblock_config.is_mamba + else "attention" + if isinstance(subblock_config, AttentionConfig) + else subblock_name + ) + assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict." 
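+    # Illustrative examples of the strings built below: an attention subblock with
+    # n_heads_in_group=4 renders as " attention gqa_4" (plus padding), an FFN subblock
+    # with intermediate_size=14336 as " ffn intermediate_14336", and a no_op subblock
+    # as " attention no_op" / " ffn no_op".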
+ + if dataclasses.is_dataclass(subblock_config): + subblock_config = dataclasses.asdict(subblock_config) + + if subblock_name == "attention" and subblock_config.get("mamba") is not None: + subblock_name = "mamba" + + if subblock_name == "ffn" and subblock_config.get("moe") is not None: + subblock_name = "moe" + + rep = f" {subblock_name}" + if subblock_config.get("no_op"): + rep += " no_op".ljust(8) + elif subblock_config.get("replace_with_linear"): + rep += " linear".ljust(8) + elif subblock_name == "ffn": + intermediate_size = subblock_config["intermediate_size"] + rep += f" intermediate_{intermediate_size}".ljust(8) + elif subblock_name == "attention": + n_heads_in_group = subblock_config["n_heads_in_group"] + rep += f" gqa_{n_heads_in_group}".ljust(8) + elif subblock_name == "mamba": + mamba_num_heads = subblock_config["mamba"]["num_heads"] + mamba_head_dim = subblock_config["mamba"]["head_dim"] + rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8) + elif subblock_name == "moe": + moe_num_local_experts = subblock_config["moe"]["num_local_experts"] + moe_expert_intermediate_dim = subblock_config["moe"]["expert_intermediate_dim"] + shared_expert_intermediate_dim = subblock_config["moe"]["shared_expert_intermediate_dim"] + num_experts_per_tok = subblock_config["moe"]["num_experts_per_tok"] + rep += ( + f" num_experts_{moe_num_local_experts} " + f"expert_intermediate_dim_{moe_expert_intermediate_dim} " + f"shared_expert_intermediate_dim_{shared_expert_intermediate_dim} " + f"num_experts_per_tok_{num_experts_per_tok}" + ).ljust(8) + else: + raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.") + + return rep + + +def pareto_frontier( + df: pd.DataFrame, + x: str, + y: str, + x_bigger_is_better: bool = False, # default: smaller x is better + y_bigger_is_better: bool = True, # default: bigger y is better +) -> pd.DataFrame: + """ + Returns the Pareto frontier (non-dominated points) from df based on the criteria: + - For the x-axis, if x_bigger_is_better is True, then higher x values are preferred; + if False, lower x values are preferred. + - For the y-axis, if y_bigger_is_better is True, then higher y values are preferred; + if False, lower y values are preferred. + + A point is considered dominated if there exists another point that is strictly better + in both dimensions. + """ + # Extract the columns as numpy arrays. + X = df[x].to_numpy().copy() # noqa: N806 + Y = df[y].to_numpy().copy() # noqa: N806 + + # Transform the coordinates so that "higher is always better" in both dimensions. + if not x_bigger_is_better: + X = -X # noqa: N806 + if not y_bigger_is_better: + Y = -Y # noqa: N806 + + n_points = len(df) + is_dominated = np.zeros(n_points, dtype=bool) + + # For each point, check if any other point is strictly better in both dimensions. + for i in range(n_points): + domination_mask = (X[i] < X) & (Y[i] < Y) + if np.any(domination_mask): + is_dominated[i] = True + + # Return the DataFrame filtered to only include non-dominated points. + return df[~is_dominated] + + +def non_max_suppression( + df: pd.DataFrame, + x: str, + y: str, + max_x_diff: float, + y_bigger_is_better: bool, +) -> pd.DataFrame: + """ + Filter rows in the DataFrame such that if two rows are within `diff` along the x-axis, + only the one with the preferred y value is kept. + + Parameters: + - df: pandas DataFrame. + - x: Column name for the x-axis values. + - y: Column name for the y-axis values. + - max_x_diff: Distance threshold on the x-axis. 
+ - y_bigger_is_better: + If True, keeps the row with the larger y value (default). + If False, keeps the row with the smaller y value. + + Returns: + - A DataFrame with only the selected rows. + """ + # Sort by y: descending if higher y is preferred, ascending if lower y is preferred. + df_sorted = df.sort_values(y, ascending=not y_bigger_is_better) + kept_indices = [] + + # Iterate over rows in the sorted DataFrame. + for idx, row in df_sorted.iterrows(): + x_val = row[x] + # Skip the row if its x value is within diff of any already-kept row. + if any(abs(x_val - df.loc[kept_idx, x]) < max_x_diff for kept_idx in kept_indices): + continue + kept_indices.append(idx) + + # Return the filtered DataFrame (optionally sorted in the original order). + return df.loc[kept_indices] + + +def soft_pareto_frontier( + df: pd.DataFrame, + x_col: str, + y_col: str, + y_bigger_is_better: bool, + window: int | Literal["auto"] = "auto", + median_diff_factor: float = 2.5, +) -> pd.DataFrame: + """ + Removes low-valued outliers in a sliding window fashion. + Good for getting a soft-pareto frontier, that keeps more than just the very best values. + The auto window is len(df) // 5. + """ + y = df.sort_values(x_col)[y_col] + if not y_bigger_is_better: + y = -y + y_to_keep = rolling_low_values_filter(y, window, median_diff_factor) + indices_to_keep = y_to_keep.index + return df.loc[indices_to_keep] + + +def rolling_low_values_filter( + s: pd.Series, + window: int | Literal["auto"] = "auto", + median_diff_factor: float = 2.5, +) -> pd.Series: + """ + Implements a rolling function that does this given a window: + 1. calculates the max in the window + 2. calculates diff=(max-y) for each y in the window + 3. calculates the median diff + 4. marks values that are smaller than max - 2 * median_diff + 5. removes from the series entries that were marked at least once + """ + s = s[s.notna()] + if window == "auto": + window = len(s) // 5 + + # Create a boolean mask with the same index as the series. + marks = pd.Series(False, index=s.index) + + # Iterate over every possible window + for start in range(len(s) - window + 1): + # Select the current window + window_slice = s.iloc[start : start + window] + + # Step 1: Compute the maximum in the window + max_val = window_slice.max() + + # Step 2: Compute the difference between max and each value in the window + diff = max_val - window_slice + + # Step 3: Compute the median of these differences + median_diff = diff.median() + + # Step 4: Identify values that are smaller than (max - median_diff_factor * median_diff) + threshold = max_val - median_diff_factor * median_diff + mark_window = window_slice < threshold + + # Update the mask: mark an index if it is marked in any window + marks.iloc[start : start + window] |= mark_window + + # Step 5: Remove all entries that were marked at least once + return s[~marks] + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None, dtype=None, quantization_mode=None): + """ + Create tensors with given device and dtype and don't run initialization + (but instead use "empty tensors", i.e. uninitialized memory). + + device: `torch.device` to work with + dtype: `torch.dtype` to work with + quantization_mode: optional string, quantization mode to work with, default `None`. 
+ Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU) + `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models + + Example:: + with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): + model = LLaMA(model_config) + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" + + assert quantization_mode is None, "Deci: we removed support for lit-llama quantization" + self.quantization_mode = quantization_mode + self.quantized_linear_cls = None + self.device = device + self.dtype = dtype + + def __enter__(self): + if self.quantized_linear_cls is not None: + self.torch_linear_cls = torch.nn.Linear + torch.nn.Linear = self.quantized_linear_cls + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.quantized_linear_cls is not None: + torch.nn.Linear = self.torch_linear_cls + return super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + if ( + self.dtype is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("dtype") is None + ): + kwargs["dtype"] = self.dtype + return func(*args, **kwargs) + + +# this is taken from torchhacks https://github.com/lernapparat/torchhacks + + +class NotYetLoadedTensor: + def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + self.metatensor = metatensor + self.archiveinfo = archiveinfo + self.storageinfo = storageinfo + self.rebuild_args = rebuild_args + + @classmethod + def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): + ret = func(*args) + if isinstance(ret, NotYetLoadedTensor): + old_lt = ret._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) + + ret._load_tensor = _load_tensor + return ret + return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) + + @classmethod + def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): + if isinstance(data, NotYetLoadedTensor): + old_lt = data._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) + + data._load_tensor = _load_tensor + return data + return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) + + @classmethod + def rebuild_tensor_v2( + cls, + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata=None, + *, + archiveinfo=None, + ): + rebuild_args = ( + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata, + ) + metatensor = torch._utils._rebuild_tensor_v2( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata, + ) + storageinfo = storage.archiveinfo + return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) + + def _load_tensor(self): + name, storage_cls, fn, device, size = self.storageinfo + dtype = self.metatensor.dtype + + uts = ( + self.archiveinfo.zipfile_context.zf.get_storage_from_record( + f"data/{fn}", + size * torch._utils._element_size(dtype), + torch.UntypedStorage, + ) + ._typed_storage() + ._untyped_storage + ) + with warnings.catch_warnings(): + 
warnings.simplefilter("ignore") + storage = torch.storage.TypedStorage( + wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True + ) + tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) + return tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] + res = func(*loaded_args, **kwargs) + # gc.collect would be costly here, maybe do it optionally + return res + + def __getattr__(self, name): + # properties + ## TODO: device, is_...?? + ## TODO: mH, mT, H, T, data, imag, real + ## name ??? + if name in { + "dtype", + "grad", + "grad_fn", + "layout", + "names", + "ndim", + "output_nr", + "requires_grad", + "retains_grad", + "shape", + "volatile", + }: + return getattr(self.metatensor, name) + if name in {"size"}: + return getattr(self.metatensor, name) + # materializing with contiguous is needed for quantization + if name in {"contiguous"}: + return getattr(self._load_tensor(), name) + + raise AttributeError(f"{type(self)} does not have {name}") + + def __repr__(self): + return f"NotYetLoadedTensor({self.metatensor!r})" + + +class LazyLoadingUnpickler(pickle.Unpickler): + def __init__(self, file, zipfile_context): + super().__init__(file) + self.zipfile_context = zipfile_context + + def find_class(self, module, name): + res = super().find_class(module, name) + if module == "torch._utils" and name == "_rebuild_tensor_v2": + return functools.partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) + elif module == "torch._tensor" and name == "_rebuild_from_type_v2": + return functools.partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) + elif module == "torch._utils" and name == "_rebuild_parameter": + return functools.partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) + return res + + def persistent_load(self, pid): + name, cls, fn, device, size = pid + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") + s.archiveinfo = pid + return s + + +class LazyLoad: + def __init__(self, fn): + self.zf = torch._C.PyTorchFileReader(str(fn)) + with BytesIO(self.zf.get_record("data.pkl")) as pkl: + mup = LazyLoadingUnpickler(pkl, self) + self.sd = mup.load() + + def __enter__(self): + return self.sd + + def __exit__(self, exc_type, exc_val, exc_tb): + del self.zf # I don't think there is a way to force closing... + self.zf = None + + +# TODO normalize_storage_type not defined +# class SavingProxyForStorage: +# def __init__(self, obj, saver, protocol_version=5): +# self.protocol_version = protocol_version +# self.saver = saver +# if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): +# raise TypeError(f"expected storage, not {type(obj)}") + +# # this logic is taken from PyTorch 2.0+ torch/serialization.py +# if isinstance(obj, torch.storage.TypedStorage): +# # PT upstream wants to deprecate this eventually... 
+# storage = obj._untyped_storage +# storage_type_str = obj._pickle_storage_type() +# storage_type = getattr(torch, storage_type_str) +# storage_numel = obj._size() +# else: +# storage = obj +# storage_type = normalize_storage_type(type(obj)) +# storage_numel = storage.nbytes() + +# storage_key = saver._write_storage_and_return_key(storage) +# location = torch.serialization.location_tag(storage) + +# self.storage_info = ( +# "storage", +# storage_type, +# storage_key, +# location, +# storage_numel, +# ) + +# def __reduce_ex__(self, protocol_version): +# assert False, "this should be handled with out of band" + +# TODO: SavingProxyForStorage not defined +# class SavingProxyForTensor: +# def __init__(self, tensor, saver, protocol_version=5): +# self.protocol_version = protocol_version +# self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) +# assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" +# storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) +# self.reduce_args = (storage_proxy, *other_reduce_args) + +# def __reduce_ex__(self, protocol_version): +# if protocol_version != self.protocol_version: +# raise RuntimeError( +# f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}" +# ) +# return self.reduce_ret_fn, self.reduce_args + + +# TODO normalize_storage_type not defined +# class IncrementalPyTorchPickler(pickle.Pickler): +# def __init__(self, saver, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.storage_dtypes = {} +# self.saver = saver +# self.id_map = {} + +# # this logic is taken from PyTorch 2.0+ torch/serialization.py +# def persistent_id(self, obj): +# # FIXME: the docs say that persistent_id should only return a string +# # but torch store returns tuples. This works only in the binary protocol +# # see +# # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects +# # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 +# if isinstance(obj, SavingProxyForStorage): +# return obj.storage_info + +# if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): +# if isinstance(obj, torch.storage.TypedStorage): +# # TODO: Once we decide to break serialization FC, this case +# # can be deleted +# storage = obj._untyped_storage +# storage_dtype = obj.dtype +# storage_type_str = obj._pickle_storage_type() +# storage_type = getattr(torch, storage_type_str) +# storage_numel = obj._size() + +# else: +# storage = obj +# storage_dtype = torch.uint8 +# storage_type = normalize_storage_type(type(obj)) +# storage_numel = storage.nbytes() + +# # If storage is allocated, ensure that any other saved storages +# # pointing to the same data all have the same dtype. 
If storage is +# # not allocated, don't perform this check +# if storage.data_ptr() != 0: +# if storage.data_ptr() in self.storage_dtypes: +# if storage_dtype != self.storage_dtypes[storage.data_ptr()]: +# raise RuntimeError( +# "Cannot save multiple tensors or storages that " +# "view the same data as different types" +# ) +# else: +# self.storage_dtypes[storage.data_ptr()] = storage_dtype + +# storage_key = self.id_map.get(storage._cdata) +# if storage_key is None: +# storage_key = self.saver._write_storage_and_return_key(storage) +# self.id_map[storage._cdata] = storage_key +# location = torch.serialization.location_tag(storage) + +# return ("storage", storage_type, storage_key, location, storage_numel) + +# return None + + +# TODO IncrementalPyTorchPickler not defined +# class IncrementalSave: +# def __init__(self, name): +# self.name = name +# self.zipfile = torch._C.PyTorchFileWriter(str(name)) +# self.has_saved = False +# self.next_key = 0 + +# def __enter__(self): +# return self + +# def store_early(self, tensor): +# if isinstance(tensor, torch.Tensor): +# return SavingProxyForTensor(tensor, self) +# raise TypeError(f"can only store tensors early, not {type(tensor)}") + +# def save(self, obj): +# if self.has_saved: +# raise RuntimeError("have already saved") +# # Write the pickle data for `obj` +# data_buf = BytesIO() +# pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) +# pickler.dump(obj) +# data_value = data_buf.getvalue() +# self.zipfile.write_record("data.pkl", data_value, len(data_value)) +# self.has_saved = True + +# def _write_storage_and_return_key(self, storage): +# if self.has_saved: +# raise RuntimeError("have already saved") +# key = self.next_key +# self.next_key += 1 +# name = f"data/{key}" +# if storage.device.type != "cpu": +# storage = storage.cpu() +# num_bytes = storage.nbytes() +# self.zipfile.write_record(name, storage.data_ptr(), num_bytes) +# return key + +# def __exit__(self, type, value, traceback): +# self.zipfile.write_end_of_file() + + +if __name__ == "__main__": + Fire() From cb5cf25c0400bd864ef1d9baa975b42b74c018ce Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 5 Nov 2025 20:34:47 +0100 Subject: [PATCH 30/49] Add activation scoring hooks for pruning Signed-off-by: Daniel Korzekwa --- .../activation_hooks/__init__.py | 15 + .../activation_hooks/hooks.py | 1086 +++++++++++++++++ .../activation_hooks/utils.py | 116 ++ .../torch/_compress/tools/validate_model.py | 6 +- 4 files changed, 1221 insertions(+), 2 deletions(-) create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py create mode 100644 modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py new file mode 100644 index 000000000..fa81fa1fd --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py @@ -0,0 +1,1086 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import argparse +import gc +import json +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, OmegaConf +from safetensors.torch import load_file as safe_load_file +from torch import nn + +# BlockConfig used at runtime, not just type hints (lines 680, 790) +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig # noqa: TC001 +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import ( + DeciLMConfig, # noqa: TC001 +) +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMDecoderLayer, + DeciLMRMSNorm, +) +from modelopt.torch._compress.tools.logger import aprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.tools.runtime import IRuntime + + +def clear_gpu_memory(clear: bool) -> None: + if clear: + gc.collect() + torch.cuda.empty_cache() + + +class ActivationsHook(ABC): + @abstractmethod + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + A hook to be registered in pytorch modules: torch.nn.Module.register_forward_hook() + + Args: + module (nn.Module): + args (tuple[torch.Tensor]): Input of the pytorch module + output (torch.Tensor): Output of the pytorch module + """ + ... + + @abstractmethod + def to_dict(self) -> dict[str, torch.Tensor]: ... + + def save_state(self) -> dict: + """ + Save the internal state of the hook for checkpointing. + + Returns: + dict: State dictionary that can be used to restore the hook's state + """ + # Default implementation - hooks should override this if they have state to save + return {} + + def load_state(self, state_dict: dict) -> None: + """ + Load the internal state of the hook from a checkpoint. 
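+
+        Hooks that accumulate running statistics (e.g. IndependentChannelContributionHook
+        below) override this together with save_state() so that long scoring runs can
+        resume from a checkpoint instead of starting over.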
+ + Args: + state_dict: State dictionary previously returned by save_state() + """ + # Default implementation - hooks should override this if they have state to load + + def get_progress_info(self) -> dict: + """ + Get progress information for this hook (e.g., current iteration, samples processed). + + Returns: + dict: Progress information + """ + # Default implementation - hooks can override to provide progress info + return {} + + @classmethod + def dump_activations_logs( + cls: type["ActivationsHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + args: argparse.Namespace, + runtime: IRuntime | None, + ): + """ + Default implementation for dumping final activation scores logs to disk. + This is called only at the end of scoring to save final results. + """ + + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = runtime.global_rank if runtime is not None else 0 + activations_log_path = activations_log_dir / f"rank_{rank}.pth" + activations_log = { + module_name: hook.to_dict() for module_name, hook in activation_hooks.items() + } + torch.save(activations_log, activations_log_path) + + if rank == 0: + args.activation_hooks_kwargs.pop("model") + json_dump( + OmegaConf.to_container(args, resolve=True) + if isinstance(args, DictConfig) + else vars(args), + activations_log_dir / "args.json", + ) + if runtime is not None: + runtime.wait_for_everyone() # rank 0 will not wait before dumping args.json + + aprint(f"Dumped final activations log to {activations_log_path}") + + @classmethod + def save_hook_states( + cls: type["ActivationsHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + runtime: IRuntime | None, + ): + """ + Save hook states for checkpointing (separate from final results). + This can be called periodically during scoring. + Note: Synchronization should be handled at a higher level to avoid deadlocks. 
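+
+        Each rank writes its hook states to hook_states_rank_<rank>.pth under
+        activations_log_dir (rank 0 is used when no runtime is given).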
+ """ + activations_log_dir = Path(activations_log_dir) + activations_log_dir.mkdir(exist_ok=True, parents=True) + rank = runtime.global_rank if runtime is not None else 0 + + hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth" + hook_states = { + module_name: hook.save_state() for module_name, hook in activation_hooks.items() + } + torch.save(hook_states, hook_states_path) + + return hook_states_path + + +class MinitronHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + weight_matrix = linear_layer.weight.float() + num_channels = linear_layer.in_features + self.agg_channel_activations = torch.zeros( + size=(num_channels,), dtype=torch.float32, device=weight_matrix.device + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + activations = args[0] + mean_abs_channel_activations = activations.float().mean( + dim=list(range(activations.ndim - 1)) + ) + self.agg_channel_activations[:] += mean_abs_channel_activations # shape [I] + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": self.agg_channel_activations.cpu(), + } + + +class MlpHook(ActivationsHook): + def __init__(self, mlp: nn.Module, activation_hooks_kwargs: dict): + features = ["l2_norm", "max_norm", "ffn_to_skip_ratio"] + self.features = {f: [] for f in features} + + self.num_batches = 0 + + def __call__( + self, module: nn.Module, mlp_input: tuple[torch.Tensor], mlp_out: torch.Tensor + ) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + # print instance and shape of input and output + print(f"MLP input shape: {mlp_input[0].shape}, output shape: {mlp_out.shape}") + print(f"MLP input type: {type(mlp_input)=}, {type(mlp_input[0])=}") + # return + + mlp_input = mlp_input[1] # unnormed input + l2_norm_skip = torch.linalg.vector_norm(mlp_input, dim=-1) # shape should be (B, T) + l2_norm_ffn_out = torch.linalg.vector_norm(mlp_out, dim=-1) # shape should be (B, T) + max_norm = torch.linalg.vector_norm( + mlp_out, ord=float("inf"), dim=-1 + ) # shape should be (B, T) + + ffn_ratio = l2_norm_ffn_out / (l2_norm_skip + l2_norm_ffn_out) + # difference_norm = torch.linalg.vector_norm(mlp_out - mlp_input, dim=-1) + + self.features["l2_norm"].append(l2_norm_ffn_out.cpu()) + self.features["max_norm"].append(max_norm.cpu()) + self.features["ffn_to_skip_ratio"].append(ffn_ratio.cpu()) + + self.num_batches += 1 + + # calculate_moe_ness + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + # f"{k}_score": v.cpu() / self.num_batches for k, v in self.features.items() + f"{k}_score": torch.cat(v) + for k, v in self.features.items() + } + + +class BlockHook(ActivationsHook): + def __init__(self, block: DeciLMDecoderLayer, activation_hooks_kwargs: dict): + features = ["l2_norm", "max_norm", "ffn_to_skip_ratio"] + self.features = {f: [] for f in features} + + self.num_batches = 0 + + def __call__( + self, + module: nn.Module, + block_input: tuple[torch.Tensor], + block_out: torch.Tensor, + ) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + # # print instance and shape of input and output + print(f"Block input shape: {block_input[0].shape}, output shape: {block_out.shape}") + print(f"Block input type: {type(block_input)=}, {type(block_input[0])=}") + print(f"Block output type: 
{type(block_out)=}") + print(f"{type(module)=}") + print(f"{block_input[0].device=}") + + block_input = block_input[0] # unnormed input + block_out_before_skip = block_out - block_input + + l2_norm_input = torch.linalg.vector_norm(block_input, dim=-1) # shape should be (B, T) + l2_norm = torch.linalg.vector_norm(block_out_before_skip, dim=-1) # shape should be (B, T) + max_norm = torch.linalg.vector_norm( + block_out_before_skip, ord=float("inf"), dim=-1 + ) # shape should be (B, T) + + diff_ratio = l2_norm / (l2_norm_input + l2_norm) + # difference_norm = torch.linalg.vector_norm(mlp_out - mlp_input, dim=-1) + + self.features["l2_norm"].append(l2_norm.cpu()) + self.features["max_norm"].append(max_norm.cpu()) + self.features["ffn_to_skip_ratio"].append(diff_ratio.cpu()) + + self.num_batches += 1 + + # calculate_moe_ness + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + # f"{k}_score": v.cpu() / self.num_batches for k, v in self.features.items() + f"{k}_score": torch.cat(v) + for k, v in self.features.items() + } + + +class IOCorrelationBlockHook(ActivationsHook): + def __init__(self, block: DeciLMDecoderLayer, activation_hooks_kwargs: dict): + layer_input_descriptors_path = activation_hooks_kwargs.get("layer_input_descriptors_path") + assert layer_input_descriptors_path is not None, ( + "layer_input_descriptors_path must be provided" + ) + assert Path(layer_input_descriptors_path).exists(), ( + f"layer_input_descriptors_path {layer_input_descriptors_path} does not exist" + ) + + self.layer_input_descriptors = safe_load_file(layer_input_descriptors_path)[ + "layer_descriptors" + ] + + features = [ + "spearman_corr", + "cosine_dist", + "pearson_corr", + "data_dep_in_spearman_corr", + "data_dep_in_cosine_dist", + "data_dep_in_pearson_corr", + "data_dep_out_spearman_corr", + "data_dep_out_cosine_dist", + "data_dep_out_pearson_corr", + ] + self.features = {f: [] for f in features} + + self.num_batches = 0 + + @staticmethod + def calculate_metrics(block_delta, layer_input_descriptors): + # Ensure tensors are on the same device + device = block_delta.device + layer_input_descriptors = layer_input_descriptors.to(device) + + B, T, E = block_delta.shape # noqa: N806 + L, _ = layer_input_descriptors.shape # noqa: N806 + + # Normalize for cosine similarity and Pearson correlation in advance + block_delta_norm = (block_delta - block_delta.mean(dim=-1, keepdim=True)) / block_delta.std( + dim=-1, keepdim=True + ) + layer_input_descriptors_norm = ( + layer_input_descriptors - layer_input_descriptors.mean(dim=-1, keepdim=True) + ) / layer_input_descriptors.std(dim=-1, keepdim=True) + + # Precompute cosine similarity (equivalent to Pearson correlation for standardized vectors) + pearson_results = ( + torch.einsum("bte,le->btl", block_delta_norm, layer_input_descriptors_norm) / E + ) + + # Precompute cosine similarity for raw normalized vectors + block_delta_cosine_norm = torch.nn.functional.normalize(block_delta, dim=-1) + layer_input_descriptors_cosine_norm = torch.nn.functional.normalize( + layer_input_descriptors, dim=-1 + ) + cosine_results = torch.einsum( + "bte,le->btl", block_delta_cosine_norm, layer_input_descriptors_cosine_norm + ) + + # Compute Spearman correlation using rank-based operations + block_delta_ranks = block_delta.argsort(dim=-1).float() + layer_input_descriptors_ranks = layer_input_descriptors.argsort(dim=-1).float() + + block_delta_ranks = ( + block_delta_ranks - block_delta_ranks.mean(dim=-1, keepdim=True) + ) / block_delta_ranks.std(dim=-1, keepdim=True) + 
layer_input_descriptors_ranks = ( + layer_input_descriptors_ranks - layer_input_descriptors_ranks.mean(dim=-1, keepdim=True) + ) / layer_input_descriptors_ranks.std(dim=-1, keepdim=True) + + spearman_results = ( + torch.einsum("bte,le->btl", block_delta_ranks, layer_input_descriptors_ranks) / E + ) + + return spearman_results.cpu(), pearson_results.cpu(), cosine_results.cpu() + + @staticmethod + def calculate_metrics_data_dependent(block_delta, layer_input_descriptors, block_input): + # Ensure tensors are on the same device + device = block_delta.device + layer_input_descriptors = layer_input_descriptors.to(device) + block_input = block_input.to(device) + + B, T, E = block_delta.shape # noqa: N806 + L, _ = layer_input_descriptors.shape # noqa: N806 + + # Compute data-dependent combination + data_dependent_results = torch.einsum("bte,le->btle", block_input, layer_input_descriptors) + + # Normalize the data-dependent descriptors + data_dependent_norm = ( + data_dependent_results - data_dependent_results.mean(dim=-1, keepdim=True) + ) / data_dependent_results.std(dim=-1, keepdim=True) + + # Normalize block_delta for comparison + block_delta_norm = (block_delta - block_delta.mean(dim=-1, keepdim=True)) / block_delta.std( + dim=-1, keepdim=True + ) + + # Compute Pearson correlation + pearson_results = torch.einsum("bte,btle->btl", block_delta_norm, data_dependent_norm) / E + + # Compute cosine similarity + block_delta_cosine_norm = torch.nn.functional.normalize(block_delta, dim=-1) + data_dependent_cosine_norm = torch.nn.functional.normalize(data_dependent_results, dim=-1) + cosine_results = torch.einsum( + "bte,btle->btl", block_delta_cosine_norm, data_dependent_cosine_norm + ) + + # Compute Spearman correlation using rank-based operations + block_delta_ranks = block_delta.argsort(dim=-1).float() + data_dependent_ranks = data_dependent_results.argsort(dim=-1).float() + + block_delta_ranks = ( + block_delta_ranks - block_delta_ranks.mean(dim=-1, keepdim=True) + ) / block_delta_ranks.std(dim=-1, keepdim=True) + data_dependent_ranks = ( + data_dependent_ranks - data_dependent_ranks.mean(dim=-1, keepdim=True) + ) / data_dependent_ranks.std(dim=-1, keepdim=True) + + spearman_results = ( + torch.einsum("bte,btle->btl", block_delta_ranks, data_dependent_ranks) / E + ) + + return spearman_results.cpu(), pearson_results.cpu(), cosine_results.cpu() + + def __call__( + self, + module: nn.Module, + block_input: tuple[torch.Tensor], + block_out: torch.Tensor, + ) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + # # print instance and shape of input and output + print(f"Block input shape: {block_input[0].shape}, output shape: {block_out.shape}") + print(f"Block input type: {type(block_input)=}, {type(block_input[0])=}") + print(f"Block output type: {type(block_out)=}") + print(f"{type(module)=}") + print(f"{block_input[0].device=}") + + block_input = block_input[0] # unnormed input shape B T E + block_delta = block_out - block_input + + spearman_results, pearson_results, cosine_results = self.calculate_metrics( + block_delta, self.layer_input_descriptors + ) + + self.features["spearman_corr"].append(spearman_results) + self.features["cosine_dist"].append(cosine_results) + self.features["pearson_corr"].append(pearson_results) + + # spearman_results, pearson_results, cosine_results = ( + # self.calculate_metrics_data_dependent( + # block_delta, self.layer_input_descriptors, block_input + # ) + # ) + # 
self.features['data_dep_in_spearman_corr'].append(spearman_results)
+        # self.features['data_dep_in_cosine_dist'].append(cosine_results)
+        # self.features['data_dep_in_pearson_corr'].append(pearson_results)
+
+        # spearman_results, pearson_results, cosine_results = (
+        #     self.calculate_metrics_data_dependent(
+        #         block_delta, self.layer_input_descriptors, block_out
+        #     )
+        # )
+        # self.features['data_dep_out_spearman_corr'].append(spearman_results)
+        # self.features['data_dep_out_cosine_dist'].append(cosine_results)
+        # self.features['data_dep_out_pearson_corr'].append(pearson_results)
+
+        self.num_batches += 1
+
+    def to_dict(self) -> dict[str, torch.Tensor]:
+        return {
+            # f"{k}_score": v.cpu() / self.num_batches for k, v in self.features.items()
+            # Only concatenate features that were actually populated; the data-dependent
+            # metrics above are currently disabled, and torch.cat fails on empty lists.
+            f"{k}_score": torch.cat(v)
+            for k, v in self.features.items()
+            if len(v) > 0
+        }
+
+
+class MinitronAbsHook(ActivationsHook):
+    def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict):
+        weight_matrix = linear_layer.weight.float()
+        num_channels = linear_layer.in_features
+        self.agg_channel_activations = torch.zeros(
+            size=(num_channels,), dtype=torch.float32, device=weight_matrix.device
+        )
+
+    def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None:
+        """
+        :param module:
+        :param args: tuple with one tensor entry (B,T,I)
+        :param output: B,T,E
+        """
+        activations = args[0]
+        mean_abs_channel_activations = (
+            activations.abs().float().mean(dim=list(range(activations.ndim - 1)))
+        )
+        self.agg_channel_activations[:] += mean_abs_channel_activations  # shape [I]
+
+    def to_dict(self) -> dict[str, torch.Tensor]:
+        return {
+            "score": self.agg_channel_activations.cpu(),
+        }
+
+
+class IndependentChannelContributionHook(ActivationsHook):
+    def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict):
+        weight_matrix = linear_layer.weight.float()
+        self.weight_norm = torch.linalg.vector_norm(weight_matrix, dim=0)
+        num_channels = linear_layer.in_features
+        self.agg_channel_activations = torch.zeros(
+            size=(num_channels,), dtype=torch.float32, device=weight_matrix.device
+        )
+
+    def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None:
+        """
+        :param module:
+        :param args: tuple with one tensor entry (B,T,I)
+        :param output: B,T,E
+        """
+        activations = args[0]
+        mean_abs_channel_activations = (
+            activations.abs().float().mean(dim=list(range(activations.ndim - 1)))
+        )
+        self.agg_channel_activations[:] += mean_abs_channel_activations  # shape [I]
+
+    def to_dict(self) -> dict[str, torch.Tensor]:
+        return {
+            "score": (self.weight_norm * self.agg_channel_activations).cpu(),
+            "weight_norm": self.weight_norm.cpu(),
+            "agg_channel_activations": self.agg_channel_activations.cpu(),
+        }
+
+    def save_state(self) -> dict:
+        """Save the internal state for checkpointing."""
+        return {
+            "agg_channel_activations": self.agg_channel_activations.cpu().clone(),
+            "weight_norm": self.weight_norm.cpu().clone(),
+        }
+
+    def load_state(self, state_dict: dict) -> None:
+        """Load the internal state from a checkpoint."""
+        self.agg_channel_activations = state_dict["agg_channel_activations"].to(
+            self.agg_channel_activations.device
+        )
+        # weight_norm should be the same as it's derived from the model weights
+        # but we can verify it matches
+        expected_weight_norm = state_dict["weight_norm"].to(self.weight_norm.device)
+        if not torch.allclose(self.weight_norm, expected_weight_norm, rtol=1e-5):
+            print(
+                "Warning: weight_norm mismatch during state loading - model weights may 
have changed" + ) + + +def get_pruning_schedule(num_channels, pruning_iters): + """ + Spending decreases monotonically when num_channels >= pruning_iters. + Intervals between spends increase monotonically when pruning_iters > num_channels. + The budget is fully utilized, and there's spending in the last iteration. + num_channels = 10, pruning_iters = 4 ==> [3, 3, 2, 2] + num_channels = 4, pruning_iters = 10 ==> [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] + """ + if num_channels >= pruning_iters: + # Case when budget is greater than or equal to iterations + q = num_channels // pruning_iters # Base spend per iteration + r = num_channels % pruning_iters # Remainder to distribute + + schedule = [] + for i in range(pruning_iters): + if i < r: + # Assign higher spend to earlier iterations + schedule.append(q + 1) + else: + schedule.append(q) + else: + # Case when iterations are greater than budget + schedule = [0] * pruning_iters + for i in range(1, num_channels + 1): + # Distribute spends at positions where intervals increase monotonically + pos = ((i * pruning_iters) // num_channels) - 1 + schedule[pos] = 1 + return schedule + + +class IterativeChannelContributionHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + self.weight_matrix = linear_layer.weight + self.num_channels = linear_layer.in_features + self.pruning_iters = activation_hooks_kwargs["validation_full_iters"] + self.clear_gpu_memory = activation_hooks_kwargs.get("clear_gpu_memory", False) + self.curr_iter = 0 + self.pruning_schedule = get_pruning_schedule( + num_channels=self.num_channels, pruning_iters=self.pruning_iters + ) + + self.agg_cont_per_channel = torch.zeros( + size=(self.num_channels,), + dtype=torch.float32, + device=self.weight_matrix.device, + ) + self.pruned_channels = [] + self.calibration_method = activation_hooks_kwargs.get("calibration_method") + self.epsilon = 1e-8 + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + activations = args[0] + n_channels_to_prune = self.pruning_schedule[self.curr_iter] + + curr_activations = activations.clone() # Shape B,T,I + curr_activations[..., self.pruned_channels] = 0 + output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E + + if self.calibration_method is None: + scaling_factor_per_token = torch.ones_like(output[..., 0]) # Shape B,T + elif self.calibration_method == "scale_by_magnitude": + output_norms = torch.linalg.vector_norm(output, dim=-1) # Shape B,T + output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T + scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon) + del output_curr_norms, output_norms + else: + raise NotImplementedError + del curr_activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + s = scaling_factor_per_token.unsqueeze(-1) * output - output_curr # Shape: (B, T, E) + s_squared_per_token = torch.sum(s**2, dim=-1) # Shape: (B, T) + b = s @ self.weight_matrix # Shape: (B, T, I) + c = torch.sum(self.weight_matrix**2, dim=0) # Shape: (I) + del s, output_curr + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution_squared = ( + s_squared_per_token.unsqueeze(2) + 2 * activations * b + (activations**2) * c + ) # Shape: (B, T, I) + del s_squared_per_token, b, c, activations + clear_gpu_memory(clear=self.clear_gpu_memory) + + contribution = torch.sqrt(contribution_squared + 
self.epsilon) # Shape: (B, T, I) + mean_cont_per_channel = torch.mean(contribution, dim=(0, 1)) # Shape: (I) + mean_cont_per_channel[self.pruned_channels] = torch.inf + del contribution, contribution_squared + clear_gpu_memory(clear=self.clear_gpu_memory) + + if n_channels_to_prune == 0: + self.agg_cont_per_channel += mean_cont_per_channel + else: + _, worst_indices = torch.topk(mean_cont_per_channel, n_channels_to_prune, largest=False) + worst_indices_list = worst_indices.tolist() + assert not set(self.pruned_channels).intersection(set(worst_indices_list)) + self.pruned_channels.extend(worst_indices_list) + self.agg_cont_per_channel.zero_() + self.curr_iter += 1 + + def to_dict(self) -> dict[str, torch.Tensor]: + assert self.num_channels == len(self.pruned_channels) + channels_importance_ascending = torch.tensor(self.pruned_channels, dtype=torch.long) + score = torch.empty(self.num_channels, dtype=torch.long) + score[channels_importance_ascending] = torch.arange(self.num_channels, dtype=torch.long) + + return { + "score": score.cpu(), + "channels_importance_ascending": channels_importance_ascending.cpu(), + } + + def save_state(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "curr_iter": self.curr_iter, + "pruned_channels": self.pruned_channels.copy(), + "agg_cont_per_channel": self.agg_cont_per_channel.cpu().clone(), + "num_channels": self.num_channels, + "pruning_iters": self.pruning_iters, + "pruning_schedule": self.pruning_schedule.copy(), + "calibration_method": self.calibration_method, + "epsilon": self.epsilon, + } + + def load_state(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.curr_iter = state_dict["curr_iter"] + self.pruned_channels = state_dict["pruned_channels"].copy() + self.agg_cont_per_channel = state_dict["agg_cont_per_channel"].to(self.weight_matrix.device) + # Verify other parameters match + assert self.num_channels == state_dict["num_channels"], "Channel count mismatch" + assert self.pruning_iters == state_dict["pruning_iters"], "Iteration count mismatch" + assert self.pruning_schedule == state_dict["pruning_schedule"], "Pruning schedule mismatch" + + def get_progress_info(self) -> dict: + """Get progress information.""" + progress = self.curr_iter / self.pruning_iters if self.pruning_iters > 0 else 0.0 + return { + "curr_iter": self.curr_iter, + "total_iters": self.pruning_iters, + "progress": progress, + "pruned_channels_count": len(self.pruned_channels), + "total_channels": self.num_channels, + } + + +class IndependentKvHeadContributionHook(ActivationsHook): + def __init__(self, linear_layer: nn.Linear, activation_hooks_kwargs: dict): + model_config: DeciLMConfig = activation_hooks_kwargs["model"].config + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + + self.optimize_for = activation_hooks_kwargs.get("optimize_for", "memory") + assert self.optimize_for in ["latency", "memory"] + + self.hidden_size = model_config.hidden_size + self.n_heads_in_group = block_config.attention.n_heads_in_group + self.num_q_heads = model_config.num_attention_heads + self.num_kv_heads = self.num_q_heads // self.n_heads_in_group + self.head_dim = getattr(model_config, "head_dim", self.hidden_size // self.num_q_heads) + + self.agg_kv_head_contributions = torch.zeros( + size=(self.num_kv_heads,), + dtype=torch.float32, + device=linear_layer.weight.device, + ) + + # Reshape weight matrix to group by KV heads + self.weight_grouped = linear_layer.weight.view( + self.hidden_size, self.num_kv_heads, 
self.head_dim * self.n_heads_in_group + ).permute((1, 0, 2)) + # weight_grouped.shape: (kv_heads, hidden_dim, head_dim * n_heads_in_group) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: The linear projection layer + :param args: tuple containing attention output tensor (B, T, num_q_heads * head_dim) + :param output: The projected output (B, T, hidden_dim) + """ + attn_out = args[0] # Shape: (B, T, num_q_heads * head_dim) + batch_size, seq_len, _ = attn_out.shape + + # Reshape attention output to group by KV heads + attn_out_grouped = attn_out.view( + batch_size, + seq_len, + self.num_kv_heads, + self.head_dim * self.n_heads_in_group, + ).unsqueeze(-2) + # attn_out_grouped.shape: (B, T, kv_heads, 1, head_dim * n_heads_in_group) + + if self.optimize_for == "latency": + # Compute contribution per KV head group + # First compute the projection for each KV head group + layer_out_grouped = attn_out_grouped @ self.weight_grouped.transpose(-1, -2) + layer_out_grouped = layer_out_grouped.squeeze(-2) + # layer_out_grouped.shape: (B, T, kv_heads, hidden_dim) + + else: + layer_out_grouped = [] + for i in range(self.num_kv_heads): + _layer_out = attn_out_grouped[:, :, i] @ self.weight_grouped[i].transpose(-1, -2) + layer_out_grouped.append(_layer_out) + layer_out_grouped = torch.cat(layer_out_grouped, dim=2) + + # Compute L2 norm of each group's contribution + contrib_per_kv_head = torch.linalg.vector_norm(layer_out_grouped, dim=-1) + # contrib_per_kv_head.shape: (B, T, kv_heads) + + contrib_per_kv_head = contrib_per_kv_head.mean(dim=(0, 1)) + # contrib_per_kv_head.shape: (kv_heads,) + + # Accumulate contributions + self.agg_kv_head_contributions += contrib_per_kv_head + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": self.agg_kv_head_contributions.cpu(), + } + + +class RouterStatsHook(ActivationsHook): + def __init__(self, router: nn.Linear, activation_hooks_kwargs: dict): + self.r_stats = torch.zeros( + size=(router.out_features,), dtype=torch.int64, device=router.weight.device + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + router_logits = output + top_k = 1 + _, router_indices = torch.topk(router_logits, top_k, dim=-1) + ids, counts = router_indices.unique(return_counts=True) + self.r_stats[ids] += counts + # for router_id in router_indices.flatten().cpu().tolist(): + # self.r_stats[router_id] += 1 + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "r_stats": self.r_stats.cpu(), + } + + def save_state(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "r_stats": self.r_stats.cpu().clone(), + } + + def load_state(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + self.r_stats = state_dict["r_stats"].to(self.r_stats.device) + + +class RankedChoiceVotingHook(ActivationsHook): + def __init__(self, router: nn.Linear, activation_hooks_kwargs: dict): + self.router_argsort: list[torch.Tensor] = [] + block_config: BlockConfig = activation_hooks_kwargs["block_config"] + self.top_k = block_config.ffn.moe.num_experts_per_tok + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + router_logits = output + num_experts = 
router_logits.shape[-1] + router_argsort = torch.argsort(router_logits, dim=-1, descending=True) + router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() + self.router_argsort.append(router_argsort) + + def to_dict(self) -> dict[str, torch.Tensor]: + router_argsort = torch.concat(self.router_argsort, dim=0) + num_tokens, num_experts = router_argsort.shape + + expert_ranks = torch.full((num_experts,), -1) + expert_counts_at_pruning_time = {} + + expert_kept_per_iteration: list[list[int]] = [] + expert_counts_per_iteration: list[dict[int, int]] = [] + + for rank in range(num_experts): + ids, counts = router_argsort[:, : self.top_k].unique(return_counts=True) + ids = ids.tolist() + counts = counts.tolist() + expert_counts = dict(zip(ids, counts)) + + expert_kept_per_iteration.append(ids) + expert_counts_per_iteration.append(expert_counts) + + least_popular_expert, min_count = min(expert_counts.items(), key=lambda tup: tup[1]) + + expert_ranks[least_popular_expert] = rank + expert_counts_at_pruning_time[least_popular_expert] = min_count + + router_argsort = router_argsort[router_argsort != least_popular_expert].view( + num_tokens, -1 + ) + + zero_shot_expert_counts = torch.zeros((num_experts,), dtype=torch.long) + for expert_id, expert_counts in expert_counts_per_iteration[0].items(): + zero_shot_expert_counts[expert_id] = expert_counts + zero_shot_expert_ranks = torch.argsort(torch.argsort(zero_shot_expert_counts)) + + return { + # who's more important with Ranked Choice Voting? bigger is better. + "expert_ranks": expert_ranks, + # who's more important with zero-shot voting? bigger is better. + "zero_shot_expert_ranks": zero_shot_expert_ranks, + # how many tokens chose each expert in the iteration where it was removed. + "expert_counts_at_pruning_time": expert_counts_at_pruning_time, + # full expert distribution per pruning iteration. + "expert_counts_per_iteration": expert_counts_per_iteration, + "top_k": self.top_k, # top_k experts per token. 
+ } + + def save_state(self) -> dict: + """Save the internal state for checkpointing.""" + return { + "router_argsort": [tensor.cpu().clone() for tensor in self.router_argsort], + "top_k": self.top_k, + } + + def load_state(self, state_dict: dict) -> None: + """Load the internal state from a checkpoint.""" + # Move tensors back to appropriate device (will be determined when hook is called) + self.router_argsort = [tensor.cpu() for tensor in state_dict["router_argsort"]] + self.top_k = state_dict["top_k"] + + def get_progress_info(self) -> dict: + """Get progress information.""" + return { + "num_batches_processed": len(self.router_argsort), + "total_tokens_processed": sum(tensor.shape[0] for tensor in self.router_argsort) + if self.router_argsort + else 0, + } + + +class RouterNumActiveExpertsStatsHook(ActivationsHook): + def __init__(self, router: nn.Linear, activation_hooks_kwargs: dict): + self.batch_sizes = [8, 16, 32, 64, 128, 256] + self.r_stats = {batch_size: [] for batch_size in self.batch_sizes} + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + router_logits = output + assert len(router_logits.shape) == 3, f"router_logits.shape: {router_logits.shape}" + top_k = 1 + _, router_indices = torch.topk(router_logits, top_k, dim=-1) + + # shuffle router_indices on dim=1 + num_samples = 5 + rand_perm = torch.randperm(router_indices.size(1)) + router_indices_shuffled = router_indices[:, rand_perm] + + for batch_size in self.batch_sizes: + btsz = router_indices.shape[0] + seq_length = router_indices.shape[1] + seq_to_take = batch_size // btsz + # sample a random starting place (random comes from randperm) + starting_place = torch.arange(0, seq_length - seq_to_take + 1, num_samples) + + for start in starting_place[:num_samples]: + num_uniques = ( + router_indices_shuffled[:, start : start + seq_to_take] + .flatten() + .unique() + .numel() + ) + self.r_stats[batch_size].append(num_uniques) + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "num_active_experts": { + batch_size: torch.tensor(self.r_stats[batch_size]) + for batch_size in self.batch_sizes + }, + } + + +class RouterNumActiveExpertsStatsHookUnshuffled(RouterNumActiveExpertsStatsHook): + def __init__(self, router: nn.Linear, activation_hooks_kwargs: dict): + super().__init__(router, activation_hooks_kwargs) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + router_logits = output + assert len(router_logits.shape) == 3, f"router_logits.shape: {router_logits.shape}" + top_k = 1 + _, router_indices = torch.topk(router_logits, top_k, dim=-1) + + for batch_size in self.batch_sizes: + seq_to_take = batch_size + num_uniques = router_indices[:, -seq_to_take:].flatten().unique().numel() + self.r_stats[batch_size].append(num_uniques) + + +class RouterEntropyHook(ActivationsHook): + def __init__(self, router: nn.Linear, activation_hooks_kwargs: dict): + self.entropy = [] + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + """ + :param module: + :param args: tuple with one tensor entry (B,T,I) + :param output: B,T,E + """ + router_logits = output + assert len(router_logits.shape) == 3, f"router_logits.shape: {router_logits.shape}" + probs = F.softmax(router_logits, dim=-1) + entropy = torch.distributions.Categorical(probs=probs).entropy() + self.entropy.append(entropy.cpu()) + + def 
to_dict(self) -> dict[str, torch.Tensor]: + return { + "entropy": torch.stack(self.entropy), + "entropy_mean": torch.stack(self.entropy).mean(dim=0), + "entropy_std": torch.stack(self.entropy).std(dim=0), + } + + +class LayerNormlContributionHook(ActivationsHook): + def __init__(self, layernorm_layer: DeciLMRMSNorm, activation_hooks_kwargs: dict): + self.agg_embedding_activations = torch.zeros( + size=(layernorm_layer.weight.shape[0],), + dtype=torch.float32, + device=layernorm_layer.weight.device, + ) + + def __call__(self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor) -> None: + self.agg_embedding_activations += ( + output.abs().float().mean(dim=list(range(output.ndim - 1))) + ) + + @classmethod + def dump_activations_logs( + cls: type["LayerNormlContributionHook"], + activation_hooks: dict[str, "ActivationsHook"], + activations_log_dir: Path | str, + args: argparse.Namespace, + runtime: IRuntime | None, + ): + """ + At the end of the default implementation of dumping activation scores to disc, + save aggregated channel importance results. + """ + + super().dump_activations_logs(activation_hooks, activations_log_dir, args, runtime) + + rank = runtime.global_rank if runtime is not None else 0 + if rank == 0: + LayerNormlContributionHook._save_channel_importance_results( + activation_hooks, activations_log_dir, args + ) + + runtime.wait_for_everyone() + + @staticmethod + def _save_channel_importance_results( + activation_hooks: dict[str, ActivationsHook], + activations_log_dir: Path, + args: argparse.Namespace, + ) -> None: + """ + Save channel importance results from activation hooks. + """ + + # Find all activation files (for multi-rank scenarios) + activations_log_dir = Path(activations_log_dir) + activation_files = list(activations_log_dir.glob("rank_*.pth")) + if not activation_files: + aprint(f"Warning: No activation files found in {activations_log_dir}") + return + + # Load and aggregate activation data from all ranks + all_scores = [] + for activation_file in activation_files: + aprint(f"Loading activations from {activation_file}") + activation_data = torch.load(activation_file, map_location="cpu") + + # Extract scores from the activation data + for module_name, hook_data in activation_data.items(): + if "score" in hook_data: + scores = hook_data["score"] + all_scores.append(scores) + aprint(f"Loaded {len(scores)} channel scores from {module_name}") + + if not all_scores: + aprint("Warning: No valid activation data found") + return + + # Average scores across all ranks and modules + avg_scores = torch.stack(all_scores).mean(dim=0) + aprint(f"Averaged {len(all_scores)} score sets into {len(avg_scores)} channels") + + # Create channel importance ranking (descending order) + ranked_channels = torch.argsort(avg_scores, descending=True).tolist() + + # Create output data structure + timestamp = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + output_data = { + "model_path": getattr(args, "model_name_or_path", "unknown"), + "dataset_path": getattr(args, "dataset_path", "unknown"), + "experiment_id": getattr(args, "experiment_id", f"experiment_{timestamp}"), + "eval_samples": getattr(args, "eval_samples", 0), + "micro_batch_size": getattr(args, "micro_batch_size", 0), + "timestamp": timestamp, + "total_channels": len(ranked_channels), + "channel_importance_ranking": ranked_channels, + "channel_scores": avg_scores.tolist(), + "score_statistics": { + "min": float(avg_scores.min()), + "max": float(avg_scores.max()), + "mean": float(avg_scores.mean()), + "std": 
float(avg_scores.std()), + }, + } + + # Save the output + output_path = activations_log_dir / "channel_importance_results.json" + aprint(f"Saving channel importance data to {output_path}") + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + # Print summary statistics + aprint("=== Channel Importance Summary ===") + aprint(f"Total channels: {len(ranked_channels)}") + aprint(f"Top 10 most important channels: {ranked_channels[:10]}") + aprint(f"Bottom 10 least important channels: {ranked_channels[-10:]}") + aprint(f"Score range: {avg_scores.min():.4f} to {avg_scores.max():.4f}") + aprint(f"Score mean: {avg_scores.mean():.4f}") + aprint(f"Score std: {avg_scores.std():.4f}") + + def to_dict(self) -> dict[str, torch.Tensor]: + return { + "score": self.agg_embedding_activations.cpu(), + "channels_importance_ascending": self.agg_embedding_activations.sort()[1].cpu(), + } diff --git a/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py new file mode 100644 index 000000000..acb633272 --- /dev/null +++ b/modelopt/torch/_compress/activation_scoring/activation_hooks/utils.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
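+
+# Usage sketch (hypothetical values, for illustration only): `register_activation_hooks`
+# below attaches a forward hook to every module whose name matches `target_layer` and
+# returns the hook instances keyed by module name, plus the selected hook class.
+#
+#     activation_hooks, hook_cls = register_activation_hooks(
+#         model,
+#         activation_hooks_kwargs={
+#             "target_layer": "mlp.down_proj",  # key into hook_class_map below
+#             "method": "independent",  # selects IndependentChannelContributionHook
+#         },
+#     )
+#     # ... run calibration forward passes through `model` ...
+#     scores = {name: hook.to_dict()["score"] for name, hook in activation_hooks.items()}
+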
+# mypy: ignore-errors
+import re
+
+from modelopt.torch._compress.activation_scoring.activation_hooks import hooks
+from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
+
+
+def register_activation_hooks(
+    model: DeciLMForCausalLM, activation_hooks_kwargs: dict
+) -> tuple[dict[str, hooks.ActivationsHook], type[hooks.ActivationsHook]]:
+    """Attach activation-scoring forward hooks to the modules selected by the kwargs.
+
+    The hook class is chosen from ``hook_class_map`` via the ``target_layer`` and
+    ``method`` entries of ``activation_hooks_kwargs``. Returns the registered hook
+    instances keyed by module name, together with the selected hook class.
+    """
+    hook_class_map = {
+        "mlp.down_proj": {
+            "independent": hooks.IndependentChannelContributionHook,
+            "iterative": hooks.IterativeChannelContributionHook,
+        },
+        "mlp": {"contraction_metrics": hooks.MlpHook},
+        "block": {"contraction_metrics": hooks.BlockHook},
+        "actual_block": {
+            "contraction_metrics": hooks.BlockHook,
+            "io_correlation_metrics": hooks.IOCorrelationBlockHook,
+        },
+        "self_attn.o_proj": {
+            "independent_kv_head_contribution": hooks.IndependentKvHeadContributionHook,
+        },
+        "router": {
+            "stats": hooks.RouterStatsHook,
+            "num_active_experts": hooks.RouterNumActiveExpertsStatsHook,
+            "num_active_experts_unshuffled": hooks.RouterNumActiveExpertsStatsHookUnshuffled,
+            "entropy": hooks.RouterEntropyHook,
+            "ranked_choice_voting": hooks.RankedChoiceVotingHook,
+        },
+        r"regex:experts\.\d+\.down_proj$": {  # For MoE
+            "independent": hooks.IndependentChannelContributionHook,
+        },
+        # TODO: maybe this is too generic, and we should have it specifically for
+        #  input_layernorm and post_attention_layernorm; now it might select qk_norms
+        "layernorm": {
+            "layer_norm_contribution": hooks.LayerNormlContributionHook,
+        },
+    }
+
+    activation_hooks = {}
+    target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj")
+
+    if target_layer.startswith("regex:"):
+        target_layer_regex = target_layer[len("regex:") :]
+        pattern = re.compile(target_layer_regex)
+
+        def match_predicate(module_name, module):
+            return pattern.search(module_name)
+    else:
+
+        def match_predicate(module_name, module):
+            return module_name.endswith(target_layer)
+
+    target_layer_hooks_map = hook_class_map.get(target_layer)
+    if target_layer_hooks_map is None:
+        raise ValueError(f"no hook classes found for: {target_layer}")
+
+    hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"])
+    if hook_class is None:
+        raise ValueError(
+            f"Unknown hook method '{activation_hooks_kwargs['method']}' for target layer: {target_layer}"
+        )
+
+    if target_layer == "block":
+        pattern = re.compile(r"^transformer\.h\.\d+$")
+
+        def match_predicate(module_name, module):
+            return pattern.match(module_name)
+
+    activation_hooks_kwargs["model"] = model
+    for module_name, module in model.named_modules():
+        if match_predicate(module_name, module):
+            block_config = None
+            if block_idx_match := re.search(r"\.(\d+)\.", module_name):
+                block_idx = int(block_idx_match.group(1))
+                block_config = model.config.block_configs[block_idx]
+            curr_activation_hooks_kwargs = {
+                **activation_hooks_kwargs,
+                "block_config": block_config,
+            }
+
+            hook = hook_class(module, curr_activation_hooks_kwargs)
+            module.register_forward_hook(hook)
+            activation_hooks[module_name] = hook
+
+            # TODO: CHECK IF WE NEED THIS FOR OTHER CASES THAN SCOUT MOE
+            #
+            # if ".experts" in module_name:
+            #     moe_module_name = module_name[:module_name.index(".experts")]
+            #     moe_module = model.get_submodule(moe_module_name)
+            #     if hasattr(moe_module, "router"):
+            #         router_module_name = moe_module_name + ".router"
+            #         if router_module_name not in activation_hooks:
+            #             router = moe_module.router
+            #             router_hook = hooks.RouterStatsHook(router, {})
+            #             router.register_forward_hook(router_hook)
+            #             activation_hooks[moe_module_name + ".router"] = router_hook
+
+    # 
Hook state loading is now handled by the checkpoint manager + # if len(activation_hooks) == 0: + # raise ValueError(f"couldn't find any hooks for {target_layer} ") + return activation_hooks, hook_class diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index b852299db..432ca4203 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# mypy: ignore-errors import argparse import textwrap from copy import deepcopy @@ -28,12 +28,14 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from utils.activation_hooks.utils import register_activation_hooks from utils.data.dataloaders import create_validation_dataloader from utils.parsing import simple_parse_args_string from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline from utils.validation import calculate_losses +from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( + register_activation_hooks, +) from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint from modelopt.torch._compress.tools.logger import aprint, mprint from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime From 6f82a67bf7c998039ac4402580cad172469b0a53 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 08:30:44 +0100 Subject: [PATCH 31/49] make validate_model self-contained Signed-off-by: Daniel Korzekwa --- .../torch/_compress/tools/validate_model.py | 11 +- .../torch/_compress/utils/data/dataloaders.py | 322 +++++++ modelopt/torch/_compress/utils/parsing.py | 446 ++++++++++ .../utils/validate_runtime_pipeline.py | 367 ++++++++ modelopt/torch/_compress/utils/validation.py | 825 ++++++++++++++++++ pyproject.toml | 1 + 6 files changed, 1968 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/_compress/utils/data/dataloaders.py create mode 100644 modelopt/torch/_compress/utils/parsing.py create mode 100644 modelopt/torch/_compress/utils/validate_runtime_pipeline.py create mode 100644 modelopt/torch/_compress/utils/validation.py diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 432ca4203..53d8ba176 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -28,10 +28,6 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from utils.data.dataloaders import create_validation_dataloader -from utils.parsing import simple_parse_args_string -from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline -from utils.validation import calculate_losses from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, @@ -40,6 +36,13 @@ from modelopt.torch._compress.tools.logger import aprint, mprint from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader +from modelopt.torch._compress.utils.parsing import simple_parse_args_string +from modelopt.torch._compress.utils.validate_runtime_pipeline import ( + HiddenStatesAndLMHead, + calculate_losses_pipeline, 
+) +from modelopt.torch._compress.utils.validation import calculate_losses # #TODO:Import slack from root utils directory # root_path = os.path.join(os.path.dirname(__file__), "..", "..") diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py new file mode 100644 index 000000000..315f2c2e8 --- /dev/null +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections.abc import Callable, Mapping, Sequence +from functools import partial +from typing import Protocol, TypeVar + +import datasets +import torch +import torch.distributed +from accelerate import Accelerator +from logger import mprint +from torch.utils.data import DataLoader, Dataset, IterableDataset +from torch.utils.data._utils.collate import collate, default_collate_fn_map +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase +from utils.data.dataset import ConstantLengthDataset + + +def collate_none_fn( + batch, *, collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None +): + return None + + +collate_fn_map_with_none_support = {**default_collate_fn_map, type(None): collate_none_fn} +collate_fn_with_none_support = partial(collate, collate_fn_map=collate_fn_map_with_none_support) + + +class LoadDatasetFn(Protocol): + def __call__( + self, dataset_path: str, content_field: str, keep_in_memory: bool = False + ) -> Mapping[str, Dataset]: ... 
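+
+
+# Minimal behavioral sketch of the None-aware collate function defined above
+# (the field names "input_ids" / "cu_seqlens" are hypothetical, for illustration
+# only): tensor fields are stacked as usual, while fields whose per-sample value
+# is None collapse to a single None instead of raising a TypeError, as the
+# default collate map would.
+def _collate_with_none_example() -> None:
+    samples = [
+        {"input_ids": torch.zeros(4, dtype=torch.long), "cu_seqlens": None},
+        {"input_ids": torch.ones(4, dtype=torch.long), "cu_seqlens": None},
+    ]
+    batch = collate_fn_with_none_support(samples)
+    # Tensors from the two samples are stacked into shape (2, 4).
+    assert batch["input_ids"].shape == (2, 4)
+    # The None-valued field is collated to None rather than raising.
+    assert batch["cu_seqlens"] is None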
+ + +def load_from_disk_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + + +def load_streaming_fn( + dataset_path: str, content_field: str, keep_in_memory: bool = False +) -> Mapping[str, Dataset]: + dataset = datasets.load_dataset( + dataset_path, + streaming=True, + features=datasets.Features( + { + content_field: datasets.Value(dtype="string"), + } + ), + keep_in_memory=keep_in_memory, + ) + + return dataset + + +def create_train_dataloader( + accelerator: Accelerator, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name="train", + keep_in_memory: bool = False, + shuffle_train_data_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, +): + mprint(f"\ncreate_train_dataloader on rank {accelerator.process_index}") + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + train_data = dataset[dataset_name] + if shuffle_train_data_seed is not None: + train_data = train_data.shuffle(seed=shuffle_train_data_seed) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + num_workers=os.cpu_count() // 2 // 8, + ) + + return train_dataloader + + +def create_validation_dataloader( + accelerator: Accelerator | None, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "__auto__", + keep_in_memory: bool = False, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, + shuffle_seed: int | None = None, +): + if accelerator is None: + accelerator = Printer() + + if accelerator.is_main_process: + if isinstance(dataset, str): + dataset = load_dataset_fn(dataset, content_field, keep_in_memory) + + if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): + valid_data = dataset + mprint( + "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" + ) + else: + assert isinstance(dataset, datasets.DatasetDict) + if dataset_name == "__auto__": + val_split_options = [] + for val_key_prefix in ("val", "test"): + if len(val_split_options) == 0: + val_split_options = [ + split + for split in dataset # DatasetDict is dict-like and supports direct iteration + if split.lower().startswith(val_key_prefix) + ] + assert len(val_split_options) == 1, ( + f"Expected exactly one validation split, got {val_split_options=} ({dataset.keys()=})" + ) + val_split = val_split_options[0] + mprint(f"Inferred validation split automatically: 
'{val_split}'") + else: + val_split = dataset_name + mprint(f"Validation split explicitly chosen: '{val_split}'") + valid_data = dataset[val_split] + + if shuffle_seed is not None: + mprint(f"Shuffling with {shuffle_seed=}") + valid_data = valid_data.shuffle(seed=shuffle_seed) + + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + infinite=False, + seq_length=block_size * micro_batch_size if varlen else block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + # return_cu_seqlens=varlen, + # seqlen_cap=block_size if varlen else None + ) + if varlen and eval_samples is not None: + eval_samples = eval_samples // micro_batch_size + val_offloaded_dataset = realize_dataset_in_memory(valid_dataset, eval_samples) + + valid_data_len = len(val_offloaded_dataset) + mprint(f"num validation examples = {valid_data_len}") + else: + val_offloaded_dataset = None + + if not isinstance(accelerator, Printer): + obj_list = [val_offloaded_dataset] + torch.distributed.broadcast_object_list(obj_list) + val_offloaded_dataset = obj_list[0] + + # let accelerate prepare to handle distributed sampling + val_dataloader = DataLoader( + val_offloaded_dataset, + batch_size=1 if varlen else micro_batch_size, + pin_memory=True, + collate_fn=collate_fn_with_none_support, + ) + + return val_dataloader + + +def realize_dataset_in_memory(dataset: IterableDataset, eval_samples: int | None) -> list[dict]: + tqdm_desc = f"realize_dataset_in_memory({eval_samples=})" + if eval_samples is None: + offloaded_dataset = list(tqdm(dataset, desc=tqdm_desc)) + else: + val_iter = iter(dataset) + offloaded_dataset = [next(val_iter) for _ in tqdm(range(eval_samples), desc=tqdm_desc)] + return offloaded_dataset + + +def create_dataloaders( + accelerator: Accelerator, + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str, + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + val_micro_batch_size: int | None = None, + eval_samples: int | None = None, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + train_dataset_name: str = "train", + val_dataset_name: str = "__auto__", + disable_validation: bool = False, + keep_in_memory: bool = False, + shuffle_train_data_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + varlen: bool = True, +): + if val_micro_batch_size is None: + val_micro_batch_size = micro_batch_size + + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory=keep_in_memory) + + train_dataloader = create_train_dataloader( + accelerator, + seed, + tokenizer, + block_size, + dataset, + content_field, + fim_rate, + fim_spm_rate, + micro_batch_size, + load_dataset_fn, + train_dataset_name, + shuffle_train_data_seed=shuffle_train_data_seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + varlen=varlen, + ) + + if not disable_validation: + val_dataloader = create_validation_dataloader( + accelerator, + seed, + tokenizer, + block_size, + dataset, + content_field, + fim_rate, + fim_spm_rate, + val_micro_batch_size, + eval_samples, + load_dataset_fn, + val_dataset_name, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + varlen=varlen, + ) + else: + val_dataloader = None + + return train_dataloader, val_dataloader + + +TensorT = TypeVar("TensorT", bound=torch.Tensor) + + +@torch.no_grad() +def 
create_padded_tensor( + tensor: TensorT, desired_shape: Sequence[int], padding_value: float = 0 +) -> TensorT: + if tensor.shape == torch.Size(desired_shape): + return tensor + + padded_tensor = torch.full( + desired_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device + ) + indices = torch.where(torch.ones_like(tensor, dtype=torch.bool)) + padded_tensor[indices] = tensor.view(-1) + return padded_tensor + + +class Printer: + is_main_process = True + process_index = None + + @staticmethod + def print(*args, **kwargs) -> None: + print(*args, **kwargs) diff --git a/modelopt/torch/_compress/utils/parsing.py b/modelopt/torch/_compress/utils/parsing.py new file mode 100644 index 000000000..6e880dd99 --- /dev/null +++ b/modelopt/torch/_compress/utils/parsing.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any + +import torch +from omegaconf import DictConfig + + +def handle_arg_string(arg): + if arg.lower() == "true": + return True + elif arg.lower() == "false": + return False + elif arg.isnumeric(): + return int(arg) + try: + return float(arg) + except ValueError: + return arg + + +def simple_parse_args_string(args_string): + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = {k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]} + return args_dict + + +def parse_json(s: str | None) -> Any: + if s is None: + return None + return json.loads(s) + + +def parse_path(s: str | None) -> Path | None: + if s is None or s == "": + return None + return Path(s) + + +def parse_dtype(dtype_name: str) -> torch.dtype: + dtype = { + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + }[dtype_name] + return dtype + + +def get_nested_key(dictionary: dict[str, Any], nested_key: str) -> Any: + """ + If nested_key is "a.b.c" returns dictionary["a"]["b"]["c"] + """ + value = dictionary + for key in nested_key.split("."): + value = value[key] + return value + + +def format_block_configs(config) -> str: + """ + Formats block_configs from a model configuration into a beautiful, readable string. + + Each line represents a layer with attention and FFN configuration. 
+ + Args: + config: PretrainedConfig object containing block_configs + + Returns: + Formatted string with layer configurations + + Example output: + ╭─────────────────────── Model Architecture ────────────────────────╮ + │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ + │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ + │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + if not hasattr(config, "block_configs") or not config.block_configs: + return "❌ No block configs found" + + lines = [] + + # Header + header = "╭─────────────────────────────────────── Model Architecture ────────────────────────────────────────╮" + lines.append(header) + + # Format each layer + for i, block in enumerate(config.block_configs, 1): + attention_info = _format_attention_config(block.attention) + ffn_info = _format_ffn_config(block.ffn) + + # Create formatted line with proper padding + layer_str = f"Layer {i:2d}" + attention_str = f"Attention: {attention_info}" + ffn_str = f"FFN: {ffn_info}" + + line = f"│ {layer_str:8s} │ {attention_str:30s} │ {ffn_str:18s} │" + lines.append(line) + + # Footer + footer = "╰────────────────────────────────────────────────────────────────────────────────────────────────────╯" + lines.append(footer) + + return "\n".join(lines) + + +def _format_attention_config(attention_config) -> str: + """Format attention configuration for display with visual indicators.""" + if not attention_config: + return "default" + + if attention_config.no_op: + return "❌ no_op" + + n_heads = attention_config.n_heads_in_group + if n_heads is not None: + return f"{n_heads} heads in group" + + if attention_config.replace_with_linear: + return "linear replacement" + + # Check for other attention types + if attention_config.mamba: + return "🐍 mamba" + if attention_config.llama4: + return "🦙 llama4" + + window_length = attention_config.window_length + if window_length is not None: + return f"windowed ({window_length})" + + if attention_config.sparsify: + return "sparse" + + return "default" + + +def _format_ffn_config(ffn_config) -> str: + """Format FFN configuration for display with visual indicators.""" + if not ffn_config: + return "default" + + if ffn_config.no_op: + return "❌ no_op" + + if ffn_config.replace_with_linear: + return "linear" + + ffn_intermediate = ffn_config.intermediate_size + if ffn_intermediate is not None: + return f"ffn_intermediate = {ffn_intermediate}" + + # Check for MoE configuration + moe_config = ffn_config.moe + if moe_config: + return "MoE" + + if ffn_config.sparsify: + return "sparse" + + return "default" + + +def format_global_config(config: DictConfig, title: str = "Global Configuration") -> str: + """ + Pretty prints a global DictConfig with nice formatting and visual indicators. 
+ + Args: + config: DictConfig object to format + title: Title to display at the top of the formatted output + + Returns: + Formatted string with configuration details + + Example output: + ╭─────────────────── Global Configuration ────────────────────╮ + │ Training │ + │ • learning_rate: 1e-4 │ + │ • batch_size: 32 │ + │ • epochs: 100 │ + │ Model │ + │ • hidden_dim: 512 │ + │ • num_layers: 6 │ + │ Data │ + │ • dataset_path: /path/to/data │ + │ • block_size: 2048 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not config: + return "❌ No configuration found" + + lines = [] + + # Calculate box width based on title + box_width = max(60, len(title) + 10) + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"\n╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + lines.extend([header, title_line]) + + def _format_value(value: Any, indent: int = 0) -> str: + """Format a value with appropriate type indicators.""" + prefix = " " * indent + + if isinstance(value, (bool, int, float)): + return f"{prefix} {value}" + elif isinstance(value, str): + # Show truncated long strings + if len(value) > 50: + return f"{prefix} {value[:47]}..." + return f"{prefix} {value}" + elif isinstance(value, (list, tuple)): + if not value: + return f"{prefix} []" + elif len(value) <= 3: + return f"{prefix} {list(value)}" + else: + return f"{prefix} [{len(value)} items]" + elif value is None: + return f"{prefix} None" + else: + return f"{prefix} {value!s}" + + def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0): + """Recursively add configuration sections.""" + if section_name: + indent_str = " " * indent + section_line = f"│ {indent_str}{section_name}" + # Pad to box width + padding_needed = box_width - len(section_line) - 1 + section_line += " " * padding_needed + "│" + lines.append(section_line) + + for key, value in cfg.items(): + if isinstance(value, DictConfig): + # Nested configuration section + _add_config_section(value, f"{key}", indent + 1) + else: + # Regular key-value pair + indent_str = " " * (indent + 1) + value_str = _format_value(value).replace(" " * 0, "").strip() + line = f"│ {indent_str} {key}: {value_str}" + # Pad to box width + if len(line) >= box_width - 1: + # Truncate long lines + line = line[: box_width - 4] + "..." + padding_needed = box_width - len(line) - 1 + line += " " * padding_needed + "│" + lines.append(line) + + # Add configuration sections + _add_config_section(config) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) + + +def format_stitched_losses( + losses_dict: dict[str, float], + best_steps_dict: dict[str, int] | None = None, + best_values_dict: dict[str, float] | None = None, + step_number: int | None = None, + title: str = "Stitched Module Losses", +) -> str: + """ + Pretty prints stitched module losses with comprehensive tracking and visual indicators. 
+ + Args: + losses_dict: Dictionary with block names as keys and current loss values as floats + best_steps_dict: Optional dictionary with block names as keys and best step numbers as values + best_values_dict: Optional dictionary with block names as keys and best loss values as floats + step_number: Optional current step number to include in summary + title: Title to display at the top of the formatted output + + Returns: + Formatted string with loss values in a comprehensive table format + + Example output: + ╭─────────────────── Stitched Module Losses ──────────────────╮ + │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ + │───────┼────────────┼───────────┼────────────┼──────────────────│ + │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ + │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ + │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + ╰──────────────────────────────────────────────────────────────╯ + """ + if not losses_dict: + return "❌ No losses found" + + lines = [] + + # Calculate statistics + loss_values = list(losses_dict.values()) + max_loss = max(loss_values) + min_loss = min(loss_values) + avg_loss = sum(loss_values) / len(loss_values) + + # Calculate box width for new layout (removed Bar column) + box_width = 74 + title_padding = (box_width - len(title) - 2) // 2 + + # Header + header = f"╭{'─' * (box_width - 2)}╮" + title_line = ( + f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" + ) + separator = ( + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " + f"{'Best Value':<12} │ {'Change from avg':<18} │" + ) + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + + lines.extend([header, title_line, separator, divider]) + + # Format each loss + for block_name, loss_value in losses_dict.items(): + # Format current loss value + loss_str = f"{loss_value:.2e}" + + # Format best step + if best_steps_dict and block_name in best_steps_dict: + best_step_str = f"Step {best_steps_dict[block_name]}" + else: + best_step_str = " --" + + # Format best value + if best_values_dict and block_name in best_values_dict: + best_value = best_values_dict[block_name] + best_value_str = f"{best_value:.2e}" + else: + best_value = loss_value # Assume current is best if no history + best_value_str = f"{best_value:.2e}" + + # Calculate change from average + change_from_avg = loss_value - avg_loss + if abs(change_from_avg) > 1e-8: # Only show if meaningful + change_str = f"{abs(change_from_avg):.1e}" + if change_from_avg > 0: + # Current is above average (worse for loss) + change_display = f"↑ +{change_str}" + else: + # Current is below average (better for loss) + change_display = f"↓ -{change_str}" + else: + # At average value + change_display = "↔ 0.0e+00" + + # Format the line + block_display = block_name.replace("block_", "").zfill(2) + + line = ( + f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " + f"{best_value_str:<12} │ {change_display:<18} │" + ) + lines.append(line) + + # Add summary statistics + lines.append(divider) + + # Build summary string with optional step number + summary_parts = [] + if step_number is not None: + summary_parts.append(f"Step {step_number}") + summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + + summary_text = ", ".join(summary_parts) + summary = f"│ Summary: {summary_text}" + + # Pad summary to box width + padding_needed = box_width - len(summary) - 1 + summary += " " * padding_needed + "│" + lines.append(summary) + + # 
Add best step summary if we have best step data + if best_steps_dict and best_values_dict: + # Find the most common best step (modal step) + step_counts = {} + for step in best_steps_dict.values(): + step_counts[step] = step_counts.get(step, 0) + 1 + + if step_counts: + modal_best_step = max(step_counts, key=step_counts.get) + + # Get values at the modal best step for blocks that have it as their best + best_step_values = [] + for block_name, best_step in best_steps_dict.items(): + if best_step == modal_best_step and block_name in best_values_dict: + best_step_values.append(best_values_dict[block_name]) + + if best_step_values: + best_step_avg = sum(best_step_values) / len(best_step_values) + best_step_max = max(best_step_values) + best_step_min = min(best_step_values) + + best_step_summary_text = ( + f"Best: Step {modal_best_step}, Avg={best_step_avg:.2e}, " + f"Max={best_step_max:.2e}, Min={best_step_min:.2e}" + ) + best_step_summary = f"│ {best_step_summary_text}" + + # Pad best step summary to box width + padding_needed = box_width - len(best_step_summary) - 1 + best_step_summary += " " * padding_needed + "│" + lines.append(best_step_summary) + + # Footer + footer = f"╰{'─' * (box_width - 2)}╯" + lines.append(footer) + + return "\n".join(lines) diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py new file mode 100644 index 000000000..dab7f90ed --- /dev/null +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
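+
+# Illustrative sketch, not part of the original helpers: the pipeline code in
+# this file exchanges arbitrary picklable Python objects between ranks through
+# sewing_kit.utils.distributed_send_obj / distributed_recv_obj, whose
+# implementation is not shown in this patch. Assuming an initialized
+# torch.distributed process group, a minimal equivalent could look like the
+# hypothetical helpers below (names are placeholders):
+def _send_obj_sketch(obj, dst: int) -> None:
+    import torch.distributed as dist
+
+    # send_object_list pickles each element and ships it to rank `dst`
+    dist.send_object_list([obj], dst=dst)
+
+
+def _recv_obj_sketch(src: int | None = None):
+    import torch.distributed as dist
+
+    buffer = [None]
+    dist.recv_object_list(buffer, src=src)  # blocks until an object arrives
+    return buffer[0]
+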
+# mypy: ignore-errors + +from statistics import mean + +import numpy as np +import torch +import torch.distributed +import wandb +from logger import mprint +from puzzle_tools.checkpoint_utils import init_module_with_state_dict +from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM, LMHead +from puzzle_tools.runtime import IRuntime +from sewing_kit import ExternalTarget, InputArgs, ModuleTarget, Needle, RemoteTarget, StitchedModule +from sewing_kit.core import InputReducer +from sewing_kit.utils import distributed_recv_obj, distributed_send_obj, fake_tensor +from torch.utils.data import DataLoader +from tqdm import tqdm +from utils.sharded_checkpoint_utils import DummyBlock +from utils.validation import _organize_outputs, calculate_batch_outputs + + +@torch.no_grad() +def validate_pipeline_inner( + runtime: IRuntime, + stitched_model: StitchedModule, + val_dataloader: DataLoader | None, +) -> float: + if runtime.is_main_process: + assert val_dataloader.batch_size is not None + model_device = next(stitched_model.parameters()).device + + with runtime.autocast(): + stitched_model.eval() + + all_logits: list[torch.Tensor] = [] + all_targets: list[torch.Tensor] = [] + losses: list[float] = [] + + if runtime.is_main_process: + input_ids: torch.Tensor + targets: torch.Tensor + + for i_batch, batch in enumerate(tqdm(val_dataloader)): + input_ids, targets = ( + batch["input_ids"].to(model_device), + batch["targets"].to(model_device), + ) + + if i_batch == 0: + num_batches = len(val_dataloader) + seq_len = input_ids.shape[1] + if torch.distributed.is_initialized(): + torch.distributed.broadcast_object_list([(num_batches, seq_len)]) + + all_targets.append(targets.cpu()) + + output = stitched_model({}, {}, input_ids) + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + + if logits is not None: + all_logits.append(logits.cpu()) + + del output, logits + + if len(all_targets) > 0: + distributed_send_obj(all_targets, dst=runtime.world_size - 1) + + else: + obj_list: list[tuple] = [None] + torch.distributed.broadcast_object_list(obj_list) + num_batches, seq_len = obj_list[0] + + fake_input_ids = fake_tensor(1, seq_len, dtype=runtime.dtype) + + for i in range(num_batches): + output = stitched_model({}, {}, fake_input_ids) + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + if logits is not None: + all_logits.append(logits.cpu()) + del output, logits + + if len(all_targets) == 0 and runtime.global_rank == runtime.world_size - 1: + all_targets = distributed_recv_obj(src=0) + + torch.distributed.barrier() + + if len(all_logits) > 0: + for logits, targets in zip(all_logits, all_targets): + logits = logits.to("cuda") + targets = targets.to("cuda") + logit_losses = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + ) + + mean_losses = logit_losses.cpu().mean(dim=-1) + losses.extend(mean_losses.tolist()) + + val_loss = mean(losses) + + if not runtime.is_main_process: + distributed_send_obj(val_loss, dst=0) + elif runtime.is_main_process: + val_loss = distributed_recv_obj() + else: + val_loss = float("nan") + + stitched_model.train() + + loss_list = [val_loss] + torch.distributed.broadcast_object_list(loss_list) + val_loss = loss_list[0] + + return val_loss + + +@torch.no_grad() +def validate_pipeline( + runtime: IRuntime, + stitched_model: StitchedModule, + 
model_config: DeciLMConfig, + val_dataloader: DataLoader, + iter_num: int | None = None, + max_iters: int | None = None, + model_name: str | None = None, + enable_print: bool = True, + enable_wandb_log: bool = False, + # pad_to_batchsize: bool = True, +) -> float: + if enable_print: + mprint("Validating ...") + + val_loss = validate_pipeline_inner( + runtime=runtime, + stitched_model=stitched_model, + val_dataloader=val_dataloader, + ) + + if runtime.is_main_process: + key = "val/loss" if model_name is None else f"val/{model_name}_loss" + if enable_print: + prefix = "" + if iter_num is not None: + prefix += f"iter {iter_num}" + if max_iters is not None: + prefix += f"/{max_iters}" + prefix += " - " + mprint(f"{prefix}{key}: {val_loss:.4f}") + if enable_wandb_log: + wandb.log({key: val_loss}, step=iter_num) + + runtime.wait_for_everyone() + + return val_loss + + +class HiddenStatesAndLMHead(list): + def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): + super().__init__(hidden_states) + self.lm_head_weights = lm_head_weights + + +@torch.no_grad() +def calculate_losses_pipeline( + runtime: IRuntime, + stitched_model: StitchedModule | DeciLMForCausalLM, + dataloader: DataLoader | None, + target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, + return_hidden_states: bool = False, + calculate_full_score_ablations: bool = False, + calc_on_cpu: bool = False, + just_model_forward: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: + """ + Do model forward on each batch and calculate LM loss. + Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. + Optionally return hidden states per batch. + Does not support data-parallel. + just_model_forward: skip loss calculation, just forward the model. Useful for activation hooks. 
+ + + Returns: + losses: dict = { + "lm_loss": { + "avg": float, + "per_sample": list[float] + } + more metrics if provided with target_hidden_states_per_batch + } + target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True + + """ + if isinstance(stitched_model, DeciLMForCausalLM): + stitched_model = perform_pipeline_stitches(stitched_model, runtime) + + params = list(stitched_model.parameters()) + model_device = params[0].device if params else "cpu" + + # Pre-populate outputs with dummy values for skipped batches + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + if runtime.is_last_process: + outputs = [{"lm_loss": [0.0]}] * start_batch + else: + outputs = None + + if runtime.is_main_process: + all_input_ids, all_targets = zip( + *[(batch["input_ids"], batch["targets"]) for batch in dataloader] + ) + if runtime.world_size > 1: + distributed_send_obj(all_targets, dst=runtime.world_size - 1) + + if runtime.is_last_process: + if runtime.world_size > 1: + all_targets = distributed_recv_obj(src=0) + + lm_head: LMHead = next( + module + for module_name, module in stitched_model.named_modules() + if "lm_head" in module_name + ) + + if target_hidden_states_per_batch is not None: + lm_head_weights = target_hidden_states_per_batch.lm_head_weights + with torch.device(model_device): + target_lm_head = init_module_with_state_dict( + {"weight": lm_head_weights}, LMHead, *lm_head_weights.shape[::-1], bias=False + ) + + if runtime.is_main_process: + num_batches = len(all_input_ids) + seq_len = all_input_ids[0].shape[1] + if runtime.world_size > 1: + torch.distributed.broadcast_object_list([num_batches, seq_len]) + + # Create progress bar with sliced range starting from checkpoint position + desc = ( + f"[rank {runtime.global_rank}] calculate_losses_pipeline(" + f"{(target_hidden_states_per_batch is None)=}, {return_hidden_states=}, {num_batches=})" + ) + progress_bar = tqdm(range(start_batch, num_batches), desc=desc) + else: + obj_list = [None, None] + if runtime.world_size > 1: + torch.distributed.broadcast_object_list(obj_list) + num_batches, seq_len = obj_list + progress_bar = range(start_batch, num_batches) + + stitched_model.eval() + + with runtime.autocast(): + for i_batch in progress_bar: + if runtime.is_main_process: + input_ids = all_input_ids[i_batch].to(model_device) + else: + input_ids = fake_tensor(1, seq_len, dtype=torch.long) + + output = stitched_model({}, {}, input_ids) + + if runtime.is_last_process: + logits = output.captured_outputs.get("model_output") + logits = getattr(logits, "logits", logits) + hidden_states = output.captured_outputs.get("hidden_states") + targets = all_targets[i_batch].to(model_device) + + target_hidden_states = None + target_logits = None + if target_hidden_states_per_batch is not None: + target_hidden_states = target_hidden_states_per_batch[i_batch] + target_hidden_states = target_hidden_states.to(hidden_states.device) + target_logits = target_lm_head(target_hidden_states) + + if just_model_forward: + batch_outputs = {"lm_loss": [-1.0] * len(targets)} + else: + batch_outputs = calculate_batch_outputs( + hidden_states, + target_hidden_states, + logits, + target_logits, + targets, + return_hidden_states, + calculate_full_score_ablations, + calc_on_cpu, + ) + + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, hidden_states_per_batch = ( + _organize_outputs(outputs) if outputs is 
not None else (None, None) + ) + + if hidden_states_per_batch is not None: + hidden_states_per_batch = HiddenStatesAndLMHead( + hidden_states_per_batch, lm_head.weight.cpu() + ) + + runtime.wait_for_everyone() + return losses, hidden_states_per_batch + + +def perform_pipeline_stitches( + model: DeciLMForCausalLM, + runtime: IRuntime, +) -> StitchedModule: + target = ModuleTarget("module", model) + stitcher = Needle() + + is_real_block = np.flatnonzero( + [not isinstance(block, DummyBlock) for block in model.model.layers] + ) + first_block, last_block = is_real_block.min(), is_real_block.max() + + if runtime.global_rank != 0: + # receive activations from previous rank + stitcher.stitch( + RemoteTarget(peer_rank=runtime.global_rank - 1).value( + name="activations", adapter=lambda x: InputArgs(x) + ), + target.input( + name=f"model.layers.{first_block}", + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + if not runtime.is_last_process: + # send activations to next rank + stitcher.stitch( + target.output(f"model.layers.{last_block}"), + RemoteTarget(peer_rank=runtime.global_rank + 1).value(name="activations"), + ) + else: + # register model output + stitcher.stitch( + target.output(name="lm_head"), + ExternalTarget().output("model_output"), + ) + stitcher.stitch( + target.output(name="model.norm"), + ExternalTarget().output("hidden_states"), + ) + + stitched_module = stitcher.knot(ignore_extra_overrides=True) + return stitched_module diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py new file mode 100644 index 000000000..256c10f6d --- /dev/null +++ b/modelopt/torch/_compress/utils/validation.py @@ -0,0 +1,825 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
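+
+# Illustrative sketch, not called anywhere: validate_parallel() below averages
+# per-rank loss lists by gathering them on rank 0 and broadcasting the mean
+# back to every rank. Reduced to its essentials, and assuming an initialized
+# torch.distributed process group, that pattern looks like this:
+def _mean_across_ranks_sketch(local_losses: list[float]) -> float:
+    import torch.distributed as dist
+
+    gathered = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
+    dist.gather_object(local_losses, gathered, dst=0)  # rank 0 receives every list
+
+    result = [float("nan")]
+    if dist.get_rank() == 0:
+        flat = [loss for rank_losses in gathered for loss in rank_losses]
+        result[0] = sum(flat) / len(flat)
+    dist.broadcast_object_list(result, src=0)  # every rank ends up with the mean
+    return result[0]
+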
+# mypy: ignore-errors +import functools +import math +from enum import Enum +from statistics import mean + +import numpy as np +import torch +import torch.distributed +import torch.nn.functional as F +import wandb +from accelerate import Accelerator +from puzzle_tools import kd_model +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper +from typing_extensions import Self +from utils.data.dataloaders import create_padded_tensor + + +@torch.no_grad() +def _validate_single( + accelerator: Accelerator, + model: torch.nn.Module, + rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> list[float]: + assert val_dataloader.batch_sampler.batch_size is not None + desired_batch_size = val_dataloader.batch_sampler.batch_size + + with accelerator.device, accelerator.autocast(): + model.eval() + + losses: list[float] = [] + + input_ids: torch.LongTensor + targets: torch.LongTensor + is_first_batch = True + for batch in tqdm(val_dataloader, disable=not accelerator.is_main_process): + if is_first_batch: + print( + f"First batch, device {accelerator.device}, input_ids: {batch['input_ids'][:4]}" + ) + is_first_batch = False + input_ids, targets = ( + batch["input_ids"].to(accelerator.device), + batch["targets"].to(accelerator.device), + ) + batch_size = input_ids.size(0) + + if pad_to_batchsize: + input_ids = create_padded_tensor( + input_ids, (desired_batch_size, *input_ids.shape[1:]) + ) + targets = create_padded_tensor(targets, (desired_batch_size, *targets.shape[1:])) + + if rope_cache is not None: + logits = model( + input_ids, rope_cache=rope_cache, varlen=varlen, concat_token_id=concat_token_id + ) + else: + logits = model(input_ids) + + if hasattr(logits, "logits"): # For HF models + logits = logits.logits + + if isinstance(logits, tuple): # For KD + logits, teacher_logits, kd_block_loss, kd_logits_loss = logits + + if compute_kl_div: + # assumes kd_logits_loss has entry for each batch item + batch_losses = kd_logits_loss[:batch_size] + else: + batch_losses = torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + )[:batch_size].mean(dim=-1) + + losses.extend(batch_losses.tolist()) + + model.train() + + return losses + + +@torch.no_grad() +def validate_parallel( + accelerator: Accelerator, + model: torch.nn.Module, + rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> float: + losses = _validate_single( + accelerator=accelerator, + model=model, + rope_cache=rope_cache, + val_dataloader=val_dataloader, + pad_to_batchsize=pad_to_batchsize, + compute_kl_div=compute_kl_div, + varlen=varlen, + concat_token_id=concat_token_id, + ) + + results = [float("nan")] + if accelerator.is_main_process: + gathered_results = [[float("nan")]] * accelerator.num_processes + torch.distributed.gather_object(losses, gathered_results) + gathered_losses = [l for result in gathered_results for l in result] + results[0] = mean(gathered_losses) + else: + torch.distributed.gather_object(losses) + + torch.distributed.broadcast_object_list(results) + val_loss = results[0] + + return val_loss + + +@torch.no_grad() +def validate( + accelerator: Accelerator, + model: torch.nn.Module, + 
rope_cache: torch.Tensor | None, + val_dataloader: DataLoader, + iter_num: int | None = None, + max_iters: int | None = None, + model_name: str | None = None, + enable_print: bool = True, + enable_wandb_log: bool = False, + pad_to_batchsize: bool = True, + compute_kl_div: bool = False, + varlen: bool = False, + concat_token_id: int | None = None, +) -> float: + if enable_print: + accelerator.print("Validating ...") + + val_loss = validate_parallel( + accelerator=accelerator, + model=model, + rope_cache=rope_cache, + val_dataloader=val_dataloader, + pad_to_batchsize=pad_to_batchsize, + compute_kl_div=compute_kl_div, + varlen=varlen, + concat_token_id=concat_token_id, + ) + + if accelerator.is_main_process: + key = "val/loss" if model_name is None else f"val/{model_name}_loss" + if enable_print: + prefix = "" + if iter_num is not None: + prefix += f"iter {iter_num}" + if max_iters is not None: + prefix += f"/{max_iters}" + prefix += " - " + accelerator.print(f"{prefix}{key}: {val_loss:.4f}", show_delta=True) + if enable_wandb_log: + wandb.log({key: val_loss}, step=iter_num) + accelerator.wait_for_everyone() + + return val_loss + + +class UnshardedLowMemorySparseTensor: + def __init__(self, x: torch.Tensor): + inds_dtype = self._infer_inds_dtype(x) + x_sparse = x.to_sparse_coo() + self._values = x_sparse.values() + self._indices = x_sparse.indices().to(inds_dtype) + self._size = x_sparse.size() + + @staticmethod + def _infer_inds_dtype(x: torch.Tensor) -> torch.dtype: + max_dim = max(x.shape) + for inds_dtype in [torch.int16, torch.int32, torch.int64]: + if torch.iinfo(inds_dtype).max >= max_dim: + return inds_dtype + + def to_sparse_coo(self) -> torch.Tensor: + return torch.sparse_coo_tensor(values=self._values, indices=self._indices, size=self._size) + + def to_dense(self) -> torch.Tensor: + return self.to_sparse_coo().to_dense() + + def to(self, *args) -> Self: + self._values = self._values.to(*args) + for arg in args: + if isinstance(arg, torch.device) or isinstance(arg, str): + self._indices = self._indices.to(arg) + return self + + +class LowMemorySparseTensor: + _max_sparse_size = torch.iinfo(torch.int32).max + + def __init__(self, x: torch.Tensor): + num_chunks = math.ceil(x.numel() / self._max_sparse_size) + self._chunk_dim = np.argmax(x.shape) + self._chunks = [ + UnshardedLowMemorySparseTensor(chunk) + for chunk in torch.chunk(x, num_chunks, dim=self._chunk_dim) + ] + + def to(self, *args) -> Self: + for chunk in self._chunks: + chunk.to(*args) + return self + + def to_dense(self) -> torch.Tensor: + return torch.concat([chunk.to_dense() for chunk in self._chunks], dim=self._chunk_dim) + + +@torch.no_grad() +def calculate_losses( + model: nn.Module, + dataloader: DataLoader, + target_probs: None = None, + return_probs: bool = False, + checkpoint_manager=None, +) -> tuple[dict[str, dict], None] | tuple[None, None]: + """ + Do model forward on each batch and calculate LM loss. + Works on lit-llama models (single gpu) and huggingface models (can be multi gpu). + Does not support data-parallel. + + ### Anything related to probs and hidden states is not supported currently! ### + calculate_losses() isn't updated according to the major refactor in + calculate_losses_pipeline() regarding hidden states. 
+ + Returns: + outputs = { + "lm_loss": list[float], + "token_accuracy_top_1": list[float], + "token_accuracy_top_5": list[float], + "token_accuracy_top_10": list[float], + } + """ + if (target_probs is not None) or return_probs: + raise NotImplementedError( + "calculate_losses() isn't updated according to the major refactor in " + "calculate_losses_pipeline() regarding hidden states." + ) + + model_device = next(model.parameters()).device + outputs = [] + + try: + num_batches = len(dataloader) + except: + num_batches = None + + # Adjust progress bar for resume + start_batch = checkpoint_manager.current_batch if checkpoint_manager else 0 + progress_bar = tqdm( + enumerate(dataloader), + total=num_batches, + desc=f"calculate_losses({(target_probs is None)=}, {return_probs=})", + ) + if start_batch > 0: + progress_bar.update(start_batch) + + for i_batch, batch in progress_bar: + # Skip batch if resuming from checkpoint + if checkpoint_manager and checkpoint_manager.should_skip_batch(i_batch): + continue + + input_ids = batch["input_ids"].to(model_device) + logits = model(input_ids) + if hasattr(logits, "logits"): + logits = logits.logits + # logits = logits.float() + + targets = batch["targets"].to(model_device) + + batch_outputs = calculate_batch_outputs( + hidden_states=None, + target_hidden_states=None, + logits=logits, + target_logits=None, + targets=targets, + return_hidden_states=False, + calculate_full_score_ablations=False, + calc_on_cpu=False, + ) + outputs.append(batch_outputs) + + # Update checkpoint progress periodically + if checkpoint_manager: + checkpoint_manager.update_progress(i_batch + 1, num_batches) + + losses, _ = _organize_outputs(outputs) + return losses, None + + +def calc_entropy(logits: torch.Tensor) -> torch.Tensor: + """ + Returns per-token entropy given a logits tensor of shape [batch_size x seq_len x vocab_size]. + The output will have shape [batch_size x seq_len]. + """ + # Convert logits to log-probabilities + log_probs = F.log_softmax(logits, dim=-1) # shape: [B x T x V] + + # Compute probabilities from log-probabilities + probs = torch.exp(log_probs) # shape: [B x T x V] + + # Entropy calculation: sum over V of (- p * log p) + ent = -torch.sum(probs * log_probs, dim=-1) # shape: [B x T] + + return ent + + +def confidence_max_softmax(logits: torch.Tensor) -> torch.Tensor: + """ + Returns per-token max-softmax confidence given a logits tensor of shape [batch_size x seq_len x vocab_size]. + The output will have shape [batch_size x seq_len]. 
+ """ + # Compute softmax probabilities + probs = F.softmax(logits, dim=-1) # shape: [B x T x V] + + # Take the maximum probability along the vocabulary dimension + max_confidence = torch.max(probs, dim=-1).values # shape: [B x T] + + return max_confidence + + +def calculate_batch_outputs( + hidden_states: torch.Tensor | None, + target_hidden_states: torch.Tensor | None, + logits: torch.Tensor, + target_logits: torch.Tensor | None, + targets: torch.Tensor, + return_hidden_states: bool, + calculate_full_score_ablations: bool, + calc_on_cpu: bool, +) -> dict: + if calc_on_cpu: + if hidden_states is not None: + hidden_states = hidden_states.cpu() + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.cpu() + if logits is not None: + logits = logits.cpu() + if target_logits is not None: + target_logits = target_logits.cpu() + if targets is not None: + targets = targets.cpu() + + batch_outputs = _calculate_ground_truth_based_scores(logits, targets) + + # _DEBUG_calculate_per_token_entropy(batch_outputs, logits) + + if (target_hidden_states is not None) or (target_logits is not None): + batch_outputs.update( + _calculate_teacher_similarity_scores( + hidden_states, + target_hidden_states, + logits, + target_logits, + calculate_full_score_ablations, + ) + ) + + if return_hidden_states: + batch_outputs["hidden_states_per_batch"] = hidden_states.cpu() + + return batch_outputs + + +def _DEBUG_calculate_per_token_entropy(batch_outputs, logits, i_batch): + import os + + # calculate the per token entropy and per token top p + entropy = calc_entropy(logits).cpu() # .view(-1)#.tolist() + msftm = confidence_max_softmax(logits).cpu() # .view(-1)#.tolist() + teacher_dir = ( + "/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/" + "meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/" + ) + # teacher_dir = ( + # '/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/' + # 'meta-llama/Meta-Llama-3.1-405B-Instruct/' + # ) + file_path = f"{teacher_dir}/validation/per_token_stats_{i_batch}.pth" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + torch.save({"entropy": entropy, "max_softmax": msftm}, file_path) + batch_outputs["entropy"] = entropy + batch_outputs["max_softmax"] = msftm + + +def _organize_outputs( + outputs_per_batch: list[dict], +) -> tuple[dict[str, dict], list[torch.Tensor] | None]: + outputs = _concatenate_batch_outputs(outputs_per_batch) + hidden_states_per_batch = outputs.pop("hidden_states_per_batch", None) + losses = { + loss_name: { + "avg": sum(loss_per_sample) / len(loss_per_sample), + "per_sample": loss_per_sample, + } + for loss_name, loss_per_sample in outputs.items() + } + return losses, hidden_states_per_batch + + +def _concatenate_batch_outputs(outputs_per_batch: list[dict]) -> dict[str, list]: + outputs = {} + for output_name in outputs_per_batch[0]: # Regular dict is directly iterable + item_list = [] + for batch_outputs in outputs_per_batch: + batch_items = batch_outputs[output_name] + if isinstance(batch_items, list | tuple): + item_list.extend(batch_items) + else: + item_list.append(batch_items) + outputs[output_name] = item_list + return outputs + + +def _calculate_per_sample_lm_loss( + logits: torch.Tensor, + targets: torch.Tensor, +) -> list[float]: + per_sample_lm_loss = ( + torch.nn.functional.cross_entropy( + logits.transpose(1, 2), targets, ignore_index=-1, reduction="none" + ) + .mean(dim=-1) + .tolist() + ) + return per_sample_lm_loss + + +def _calculate_ground_truth_based_scores( + logits: torch.Tensor, + targets: 
torch.Tensor, +) -> dict[str, list[float]]: + scores = {"lm_loss": _calculate_per_sample_lm_loss(logits, targets)} + + for top_k in (1, 5, 10): + top_k_predictions = logits.topk(top_k, dim=-1).indices # [b, t, top_k] + is_target_in_predictions = (targets.unsqueeze(-1) == top_k_predictions).any( + dim=-1 + ) # [b, t] + fraction_model_predicted_target = is_target_in_predictions.float().mean(dim=-1) # [b] + scores[f"token_accuracy_top_{top_k}"] = fraction_model_predicted_target.tolist() + + return scores + + +def _calculate_per_sample_kl_div_loss( + logits: torch.Tensor, + batch_target_probs: torch.Tensor | LowMemorySparseTensor, +) -> list[float]: + if isinstance(batch_target_probs, LowMemorySparseTensor): + logits = top_p_top_k(logits) + curr_target_probs = batch_target_probs.to_dense().to(logits.device) # .float() + per_sample_kl_div = [ + F.kl_div( + logits[i_sample].log_softmax(-1), + curr_target_probs[i_sample], + reduction="none", + log_target=False, + ) + .sum(-1) + .mean(-1) + .item() + for i_sample in range(logits.shape[0]) + ] + return per_sample_kl_div + + +def cosine_embedding_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return kd_model.cosine_embedding_loss_batched(hidden_states, target_hidden_states).tolist() + + +def normalized_mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + kd_model.normalized_mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mse_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.mse_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def mae_loss( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, +) -> list[float]: + return [ + F.l1_loss(hidden_states[i_sample], target_hidden_states[i_sample]).item() + for i_sample in range(hidden_states.shape[0]) + ] + + +def _calculate_teacher_similarity_scores( + hidden_states: torch.Tensor, + target_hidden_states: torch.Tensor, + logits: torch.Tensor, + target_logits: torch.Tensor, + calculate_full_score_ablations: bool, +) -> dict[str, list[float]]: + """ + hidden_states: [batch, tokens, n_embd] + target_hidden_states: [batch, tokens, n_embd] + logits: [batch, tokens, vocab] + target_logits: [batch, tokens, vocab] + """ + + def calc_per_sample(func, logits, target_probs): + return [ + func(logits=logits[i_sample], target_probs=target_probs[i_sample]) + for i_sample in range(logits.shape[0]) + ] + + score_ablations = {} + + if (target_hidden_states is not None) and (hidden_states.shape == target_hidden_states.shape): + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_hidden_states" + score_ablations[score_name] = func(hidden_states, target_hidden_states) + + if target_logits is not None: + for func in (cosine_embedding_loss, normalized_mse_loss, mse_loss, mae_loss): + score_name = f"{func.__name__}_logits" + score_ablations[score_name] = func(logits, target_logits) + + for top_p in (0.99, 0.95, None) if calculate_full_score_ablations else (None,): + transformed_logits = ( + logits if (top_p is None) else top_p_top_k(logits, top_p=top_p, top_k=None) + ) + transformed_target_logits = ( + target_logits + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + target_probs = 
transformed_target_logits.softmax(-1) + + for func in (kl_div, js_div, tv_dist): + for clip_epsilon in ( + ( + ClipEpsilon.NO_CLIP, + ClipEpsilon.CLIP_NO_RENORMALIZE, + ClipEpsilon.CLIP_RENORMALIZE, + ) + if calculate_full_score_ablations + else (ClipEpsilon.NO_CLIP,) + ): + epsilon_factors = ( + (1.0, 0.1, 0.01) if not clip_epsilon == ClipEpsilon.NO_CLIP else (None,) + ) + + for epsilon_factor in epsilon_factors: + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_{clip_epsilon.name}" + f"--epsilon_factor_{epsilon_factor}" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, transformed_logits, target_probs + ) + if (top_p is None) and (clip_epsilon == ClipEpsilon.NO_CLIP): + short_score_name = func.__name__ + score_ablations[short_score_name] = score_ablations[score_name] + + for top_k in (1, 5, 10): + teacher_greedy_prediction = target_logits.argmax(dim=-1, keepdim=True) # [b,t,1] + student_top_k_predictions = logits.topk(top_k, dim=-1).indices # [b,t,k] + is_teacher_prediction_in_student_predictions = ( + teacher_greedy_prediction == student_top_k_predictions + ).any(dim=-1) # [b,t] + fraction_student_predicted_teacher = ( + is_teacher_prediction_in_student_predictions.float().mean(dim=-1) + ) # [b] + score_ablations[f"greedy_teacher_prediction_in_student_top_{top_k}"] = ( + fraction_student_predicted_teacher.tolist() + ) + + if calculate_full_score_ablations: + for top_p in (0.99, 0.95, 0.50, None): + # student + transformed_logits = logits.clone() + + # teacher + transformed_target_logits = ( + target_logits.clone() + if (top_p is None) + else top_p_top_k(target_logits, top_p=top_p, top_k=None) + ) + + target_probs = transformed_target_logits.softmax(-1) + mask = transformed_target_logits == -1000 + if torch.any(mask): + transformed_logits[mask] = 0 + transformed_target_logits[mask] = 0 + target_probs[mask] = 0 + + for func in (mse_loss, mae_loss): + score_name = f"{func.__name__}_logits_top_p_{top_p}" + score_ablations[score_name] = func( + transformed_logits, transformed_target_logits + ) + + if top_p is not None and top_p > 0.9: + func = kl_div + clip_epsilon = ClipEpsilon.NO_CLIP + score_name = ( + f"{func.__name__}--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + ) + func_with_args = functools.partial( + func, clip_epsilon=clip_epsilon, epsilon_factor=epsilon_factor + ) + score_ablations[score_name] = calc_per_sample( + func_with_args, logits, target_probs + ) + # score_name = f"{func.__name__}_abs--top_p_{top_p}--clip_epsilon_no_clip_student_unfiltered" + # score_ablations[score_name] = [s.abs() for s in score_ablations[score_name]] + + return score_ablations + + +class ClipEpsilon(Enum): + NO_CLIP = "NO_CLIP" + CLIP_RENORMALIZE = "CLIP_RENORMALIZE" + CLIP_NO_RENORMALIZE = "CLIP_NO_RENORMALIZE" + + +def _logits_to_logprobs( + logits: torch.Tensor, clip_epsilon: ClipEpsilon, epsilon_factor: float +) -> torch.Tensor: + """ + logits: [tokens, vocab] + """ + logprobs = logits.log_softmax( + -1 + ) # must normalize logits before clipping otherwise log(1/voacb) means nothing + if clip_epsilon == ClipEpsilon.NO_CLIP: + return logprobs + vocab_size = logprobs.shape[-1] + epsilon = math.log(epsilon_factor * 1 / vocab_size) + logprobs = torch.clip(logprobs, min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + logprobs = logprobs.log_softmax( + -1 + ) # we do log_softmax again to retain legitimate distributions + return logprobs 
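+
+
+# Illustrative sketch, not called anywhere: what the ClipEpsilon modes above do
+# to a toy distribution. With epsilon_factor=1.0 the clipping floor is
+# log(1 / vocab_size), i.e. no token may end up less likely than under a
+# uniform distribution over the vocabulary.
+def _clip_epsilon_demo() -> None:
+    toy_logits = torch.tensor([[10.0, 0.0, -10.0, -20.0]])  # [tokens=1, vocab=4]
+    for mode in (
+        ClipEpsilon.NO_CLIP,
+        ClipEpsilon.CLIP_NO_RENORMALIZE,
+        ClipEpsilon.CLIP_RENORMALIZE,
+    ):
+        logprobs = _logits_to_logprobs(toy_logits, mode, epsilon_factor=1.0)
+        # NO_CLIP and CLIP_RENORMALIZE yield valid distributions (probabilities
+        # sum to 1); CLIP_NO_RENORMALIZE can make them sum to more than 1,
+        # because small probabilities are raised to the floor without
+        # renormalizing.
+        print(mode.name, logprobs.exp().sum(-1).item())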
+ + +def kl_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Kullback-Leibler Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens = logits.shape[0] + logprobs = _logits_to_logprobs(logits, clip_epsilon, epsilon_factor) + + _kl_div = ( + F.kl_div(logprobs, target_probs, reduction="sum", log_target=False).item() / num_tokens + ) + return _kl_div + + +def js_div( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Jensen-Shannon Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + probs = logits.softmax(-1) + mixture_probs = (probs + target_probs) / 2 + mixture_logprobs = mixture_probs.log().clip(min=-1000) + + pred_kl_div = kl_div(mixture_logprobs, probs, clip_epsilon, epsilon_factor) + target_kl_div = kl_div(mixture_logprobs, target_probs, clip_epsilon, epsilon_factor) + _js_div = 0.5 * (pred_kl_div + target_kl_div) + return _js_div + + +def tv_dist( + logits: torch.Tensor, + target_probs: torch.Tensor, + clip_epsilon: ClipEpsilon = ClipEpsilon.NO_CLIP, + epsilon_factor: float = 1.0, +) -> float: + """ + Total Variation Distance (L1-loss) for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + num_tokens, vocab_size = logits.shape + probs = logits.softmax(-1) + + if clip_epsilon != ClipEpsilon.NO_CLIP: + epsilon = epsilon_factor * 1 / vocab_size + probs = probs.clip(min=epsilon) + target_probs = target_probs.clip(min=epsilon) + if clip_epsilon == ClipEpsilon.CLIP_RENORMALIZE: + probs = probs / probs.sum(-1, keepdim=True) + target_probs = target_probs / target_probs.sum(-1, keepdim=True) + + _tv_dist = 0.5 * (probs - target_probs).abs().sum().item() / num_tokens + return _tv_dist + + +DEFAULT_TOP_P = 0.999 +# WestLake model: +# 700 = percentile 0.9 for top_p=0.99 +# 1700 = percentile 0.95 for top_p=0.99 and percentile 0.75 for top_p=0.999 +# For top_p=0.999 and top_k=1700 you take about 75 GB for 2048*8192 tokens +DEFAULT_TOP_K = 1000 + + +def calculate_sparse_probs( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + verbose: bool = False, +) -> LowMemorySparseTensor: + warped_logits = top_p_top_k(logits, top_p, top_k) + probs = warped_logits.softmax(-1) + sparse_probs = LowMemorySparseTensor(probs) + if True: # Always calculate these metrics (was: if verbose or True:) + probs_unfiltered = logits.softmax(-1) + num_active_per_token = (warped_logits > -1000).sum(-1).float() + prob_density = torch.tensor( + [ + probs_unfiltered[i, j, warped_logits[i, j] > -1000].sum(-1).float() + for j in range(probs_unfiltered.shape[1]) + for i in range(probs_unfiltered.shape[0]) + ] + ) + + print(f""" + Sparsity: + {num_active_per_token.mean().item()=} + {num_active_per_token.quantile(0.25).item()=} + {num_active_per_token.quantile(0.5).item()=} + {num_active_per_token.quantile(0.75).item()=} + {num_active_per_token.quantile(0.9).item()=} + {num_active_per_token.quantile(0.95).item()=} + {num_active_per_token.max().item()=} + + {probs_unfiltered.shape=} + {prob_density.shape=} + {prob_density.mean().item()=} + {prob_density.quantile(0.25).item()=} + {prob_density.quantile(0.5).item()=} + {prob_density.quantile(0.75).item()=} + {prob_density.quantile(0.9).item()=} + {prob_density.quantile(0.95).item()=} + 
{prob_density.max().item()=} + """) + return sparse_probs + + +def top_p_top_k( + logits: torch.Tensor, + top_p: float | None = DEFAULT_TOP_P, + top_k: int | None = DEFAULT_TOP_K, + filter_value=-1000, +) -> torch.Tensor: + logit_warpers = [] + if top_p is not None: + logit_warpers.append(TopPLogitsWarper(top_p=top_p, filter_value=filter_value)) + if top_k is not None: + logit_warpers.append(TopKLogitsWarper(top_k=top_k, filter_value=filter_value)) + + warped_logits = [] + for sample_logits in logits: + for warper in logit_warpers: + sample_logits = warper(input_ids=None, scores=sample_logits) + warped_logits.append(sample_logits) + warped_logits = torch.stack(warped_logits) + + return warped_logits diff --git a/pyproject.toml b/pyproject.toml index 8ae14292d..29f951c2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ extend-ignore = [ "*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py) "*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/_compress/*" = ["C4", "D", "E", "F", "FURB", "I", "ISC", "N", "PERF", "PGH", "PIE", "PLE", "PLR", "PT", "RUF", "SIM", "TC", "UP", "W"] # TODO:Disabled for now, will enable later, once all puzzletron code is migrated [tool.ruff.lint.pycodestyle] From a87fb7908b1d4450a168a9fefe455e786af3d1fb Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 08:43:44 +0100 Subject: [PATCH 32/49] updage validatete_pipeline to use DeciLMForCausalLM from modelopt Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/sharded_checkpoint_utils.py | 2 +- .../torch/_compress/utils/validate_runtime_pipeline.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 08ee7e9d4..4f6a64ee9 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -24,7 +24,7 @@ import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards from logger import mprint -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index dab7f90ed..d76572e7f 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -22,8 +22,11 @@ import wandb from logger import mprint from puzzle_tools.checkpoint_utils import init_module_with_state_dict -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM, LMHead +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMForCausalLM, + LMHead, +) from puzzle_tools.runtime import IRuntime from sewing_kit import ExternalTarget, InputArgs, ModuleTarget, Needle, RemoteTarget, StitchedModule from sewing_kit.core import InputReducer From 
b2275215005af076b9e081b4834d7f1612d16430 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 09:08:03 +0100 Subject: [PATCH 33/49] fix imports Signed-off-by: Daniel Korzekwa --- .../score_pruning_activations.py | 2 +- .../tools/sharded_checkpoint_utils.py | 2 +- .../torch/_compress/utils/data/dataloaders.py | 4 +- .../torch/_compress/utils/data/dataset.py | 319 ++++++++++++++++++ modelopt/torch/_compress/utils/utils.py | 6 +- .../utils/validate_runtime_pipeline.py | 8 +- 6 files changed, 332 insertions(+), 9 deletions(-) create mode 100644 modelopt/torch/_compress/utils/data/dataset.py diff --git a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py index 3617bdb1c..ef1e6c273 100644 --- a/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/_compress/activation_scoring/score_pruning_activations.py @@ -18,7 +18,7 @@ import hydra import torch from omegaconf import DictConfig -from utils.parsing import format_global_config +from modelopt.torch._compress.utils.parsing import format_global_config from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers from modelopt.torch._compress.tools.logger import mprint diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 4f6a64ee9..12e059fa5 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -23,7 +23,7 @@ import torch.distributed import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards -from logger import mprint +from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from safetensors import safe_open from safetensors.torch import load_file as safe_load_file diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py index 315f2c2e8..4fc856fbb 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -22,12 +22,12 @@ import torch import torch.distributed from accelerate import Accelerator -from logger import mprint +from modelopt.torch._compress.tools.logger import mprint from torch.utils.data import DataLoader, Dataset, IterableDataset from torch.utils.data._utils.collate import collate, default_collate_fn_map from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from utils.data.dataset import ConstantLengthDataset +from modelopt.torch._compress.utils.data.dataset import ConstantLengthDataset def collate_none_fn( diff --git a/modelopt/torch/_compress/utils/data/dataset.py b/modelopt/torch/_compress/utils/data/dataset.py new file mode 100644 index 000000000..7398b378c --- /dev/null +++ b/modelopt/torch/_compress/utils/data/dataset.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import functools +from typing import Optional +from typing import Sequence + +import numpy as np +import torch +from torch.utils.data import IterableDataset + +from modelopt.torch._compress.tools.logger import aprint, mprint + +FIM_TOKEN_START = "", "middle>", "suffix>", "pad>"] +CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + Args: + tokenizer (Tokenizer): The processor used for proccessing the data. + dataset (dataset.Dataset): Dataset with text files. + infinite (bool): If True the iterator is reset after dataset reaches end else stops. + seq_length (int): Length of token sequences to return. + num_of_sequences (int): Number of token sequences to keep in buffer. + chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer. + fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM. + fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM. + seed (int): Seed for random number generator. + label_shift (bool): Whether to shift labels by 1 or not. + """ + + def __init__( + self, + tokenizer, + dataset, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + content_field="content", + fim_rate=0.5, + fim_spm_rate=0.5, + seed=0, + label_shift=True, + max_sample_length=200_000, + tokens_field="token_ids", + source_datasets_to_discard: Optional[Sequence[str]] = tuple(), + bos_rate: float = 1.0, + return_cu_seqlens: bool = False, + seqlen_cap: Optional[int] = None, + ): + self.tokenizer = tokenizer + self.concat_token_id = tokenizer.eos_token_id + # self.concat_token_id = tokenizer.eos_id # for lit-lamma tokenizer + self.dataset = dataset + self.is_dataset_already_tokenized = tokens_field in self.dataset.column_names + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + if not self.is_dataset_already_tokenized: + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.max_sample_length = max_sample_length + else: + self.max_buffer_size = seq_length * num_of_sequences + # self.max_sample_length = int(max_sample_length / chars_per_token) + self.max_sample_length = max_sample_length # we don't know the exact chars_per_token + self.content_field = content_field + self.tokens_field = tokens_field + self.fim_rate = fim_rate + self.fim_spm_rate = fim_spm_rate + self.seed = seed + self.max_sample_length = max_sample_length + + self.fim_token_ids = get_fim_token_ids(self.tokenizer) + if None in self.fim_token_ids.values() and self.fim_rate > 0: + self.fim_rate = 0 + self.label_shift = label_shift + self.bos_rate = bos_rate + self.source_datasets_to_discard = ( + source_datasets_to_discard if source_datasets_to_discard is not None else tuple() + ) + self.return_cu_seqlens = return_cu_seqlens + self.seqlen_cap = seqlen_cap + self.np_rng = np.random.RandomState(seed=self.seed) + + def __iter__(self) -> dict[str, torch.Tensor]: + iterator = iter(self.dataset) + 
more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + sample = next(iterator) + if ( + len(self.source_datasets_to_discard) > 0 + and sample["dataset_name"] in self.source_datasets_to_discard + ): + continue + if not self.is_dataset_already_tokenized: + sample = sample[self.content_field] + if ( + isinstance(sample, list) + and isinstance(sample[0], dict) + and {"content", "role"}.issubset(sample[0]) + ): + if len(sample) > 1: + sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + else: + sample = sample[0]["content"] + else: + sample = sample[self.tokens_field] + sample = sample[: self.max_sample_length] + buffer.append(sample) + buffer_len += len(sample) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + else: + more_examples = False + break + + if not self.is_dataset_already_tokenized: + tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] + else: + tokenized_inputs = buffer + + all_token_ids = [] + + for tokenized_input in tokenized_inputs: + if ( + self.bos_rate < 1.0 + and not self.np_rng.binomial(1, self.bos_rate) + and self.tokenizer.bos_token_id is not None + and tokenized_input[0] == self.tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] + # optionally do FIM permutations + if self.fim_rate > 0: + tokenized_input, np_rng = permute( + sample=tokenized_input, + np_rng=self.np_rng, + fim_token_ids=self.fim_token_ids, + fim_rate=self.fim_rate, + fim_spm_rate=self.fim_spm_rate, + truncate_or_pad=False, + ) + + all_token_ids.extend(tokenized_input + [self.concat_token_id]) + + examples = [] + # cuts code snippets in the middle to yield constant length instances + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + labels = all_token_ids[ + i + int(self.label_shift) : i + int(self.label_shift) + self.seq_length + ] + # ignores last short example in the buffer + if len(labels) == self.seq_length: + examples.append((input_ids, labels)) + + shuffling_indices = self.np_rng.permutation(len(examples)) + examples = [examples[i] for i in shuffling_indices] + + for input_ids, labels in examples: + self.current_size += 1 + input_ids = torch.LongTensor(input_ids) + if self.return_cu_seqlens: + cu_seqlens = self.prepare_cu_seqlens(input_ids) + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + "cu_seqlens": cu_seqlens, + } + else: + yield { + "input_ids": input_ids, + "targets": torch.LongTensor(labels), + } + + def prepare_cu_seqlens(self, input_ids): + if not self.return_cu_seqlens: + return None + # seqlens is of shape (num_seqs+1,) and with the property that + # the i-th sequnce is input_ids[seqlens[i-1]:seqlens[i]] + cu_seqlens = (input_ids == self.concat_token_id).nonzero().squeeze(-1).int() + 1 + cu_seqlens = torch.cat( + ( + torch.IntTensor([0]), + cu_seqlens, + torch.IntTensor([len(input_ids)]), + ) + ) + if self.seqlen_cap is not None: + i = 1 + while i < len(cu_seqlens): + curr_seqlen = cu_seqlens[i] - cu_seqlens[i - 1] + if curr_seqlen > self.seqlen_cap: + cu_seqlens = torch.cat( + (cu_seqlens[:i], cu_seqlens[[i - 1]] + self.seqlen_cap, cu_seqlens[i:]) + ) + i += 1 + if cu_seqlens[-1] == cu_seqlens[-2]: + cu_seqlens = cu_seqlens[:-1] + return cu_seqlens + + +## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py +def permute( + sample, + np_rng, + 
fim_token_ids, + fim_rate=0.5, + fim_spm_rate=0.5, + truncate_or_pad=False, +): + """ + Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes: + PSM and SPM (with a probability of fim_spm_rate). + """ + + if np_rng.binomial(1, fim_rate): + boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2)) + boundaries.sort() + + prefix = np.array(sample[: boundaries[0]], dtype=np.int64) + middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64) + suffix = np.array(sample[boundaries[1] :], dtype=np.int64) + + if truncate_or_pad: + raise NotImplementedError + + if "" in fim_token_ids: # use codegen FIM pattern + assert fim_spm_rate == 0 + new_sample = np.concatenate( + [ + prefix, + [fim_token_ids[""]], + suffix, + [fim_token_ids["<|endoftext|>"]], + [fim_token_ids[""]], + [fim_token_ids[""]], + middle, + ] + ) + elif np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"], fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + prefix, + middle, + ] + ) + else: + # PSM + new_sample = np.concatenate( + [ + [fim_token_ids["prefix_tok_id"]], + prefix, + [fim_token_ids["suffix_tok_id"]], + suffix, + [fim_token_ids["middle_tok_id"]], + middle, + ] + ) + else: + # don't do FIM preproc + new_sample = sample + + return list(new_sample), np_rng + + +# this is expensive so we cache it +@functools.lru_cache(maxsize=None) +def get_fim_token_ids(tokenizer): + # ugly fix for Salesforce/codegen25-7b-multi tokenizer + if hasattr(tokenizer, "encoder"): + search_vocab = tokenizer.encoder._special_tokens + fim_token_ids = {tok: search_vocab.get(tok, None) for tok in CODEGEN_FIM_TOKENS} + else: + search_vocab = tokenizer.vocab + if (FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + FIM_TOKEN_END_LIST[0]) in search_vocab: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_STAR + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + else: + prefix_tok_id, middle_tok_id, suffix_tok_id, pad_tok_id = ( + search_vocab.get(FIM_TOKEN_START + FIM_TOKEN_CONNECTOR_SANTA + tok, None) + for tok in FIM_TOKEN_END_LIST + ) + fim_token_ids = { + "suffix_tok_id": suffix_tok_id, + "prefix_tok_id": prefix_tok_id, + "middle_tok_id": middle_tok_id, + "pad_tok_id": pad_tok_id, + } + return fim_token_ids diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index 197e47068..ed3f39eb9 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -41,7 +41,11 @@ import pandas as pd import torch from fire import Fire -from puzzle_tools.deci_lm_hf_code.block_config import AttentionConfig, BlockConfig, FFNConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) from tqdm import tqdm diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index d76572e7f..3c735b431 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -20,20 +20,20 @@ import torch import torch.distributed import wandb -from logger import mprint -from puzzle_tools.checkpoint_utils import init_module_with_state_dict +from modelopt.torch._compress.tools.logger import mprint +from 
modelopt.torch._compress.tools.checkpoint_utils import init_module_with_state_dict from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMForCausalLM, LMHead, ) -from puzzle_tools.runtime import IRuntime +from modelopt.torch._compress.tools.runtime import IRuntime from sewing_kit import ExternalTarget, InputArgs, ModuleTarget, Needle, RemoteTarget, StitchedModule from sewing_kit.core import InputReducer from sewing_kit.utils import distributed_recv_obj, distributed_send_obj, fake_tensor from torch.utils.data import DataLoader from tqdm import tqdm -from utils.sharded_checkpoint_utils import DummyBlock +from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock from utils.validation import _organize_outputs, calculate_batch_outputs From ca7ab3ff33884e5786ccc939d8c4b1a352d3898a Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 09:48:57 +0100 Subject: [PATCH 34/49] add sewing_kit Signed-off-by: Daniel Korzekwa --- .../torch/_compress/sewing_kit/__init__.py | 34 + modelopt/torch/_compress/sewing_kit/common.py | 19 + modelopt/torch/_compress/sewing_kit/core.py | 881 ++++++++++++++++++ .../_compress/sewing_kit/passage/__init__.py | 28 + .../_compress/sewing_kit/passage/core.py | 462 +++++++++ .../sewing_kit/passage/recipes/__init__.py | 15 + modelopt/torch/_compress/sewing_kit/utils.py | 506 ++++++++++ 7 files changed, 1945 insertions(+) create mode 100644 modelopt/torch/_compress/sewing_kit/__init__.py create mode 100644 modelopt/torch/_compress/sewing_kit/common.py create mode 100644 modelopt/torch/_compress/sewing_kit/core.py create mode 100644 modelopt/torch/_compress/sewing_kit/passage/__init__.py create mode 100644 modelopt/torch/_compress/sewing_kit/passage/core.py create mode 100644 modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py create mode 100644 modelopt/torch/_compress/sewing_kit/utils.py diff --git a/modelopt/torch/_compress/sewing_kit/__init__.py b/modelopt/torch/_compress/sewing_kit/__init__.py new file mode 100644 index 000000000..6df9f8afa --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
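+
+# Illustrative usage sketch (hypothetical helper, not exported): mirrors how
+# modelopt/torch/_compress/utils/validate_runtime_pipeline.py stitches a model
+# so that an intermediate output can be captured externally. The exact
+# semantics of the targets and of Needle.knot() are defined in .core and may
+# differ from this simplified reading.
+def _capture_submodule_output_sketch(model, submodule_name: str, *model_args):
+    from .core import ExternalTarget, ModuleTarget, Needle
+
+    target = ModuleTarget("module", model)  # wrap the model as a stitch target
+    needle = Needle()
+    # route the named submodule's output to an externally visible capture slot
+    needle.stitch(target.output(name=submodule_name), ExternalTarget().output("captured"))
+    stitched = needle.knot(ignore_extra_overrides=True)
+    # stitched modules are called as (input_overrides, output_overrides, *args)
+    result = stitched({}, {}, *model_args)
+    return result.captured_outputs.get("captured")
+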
+# mypy: ignore-errors +from .core import ( + Needle, + KnotException, + LoopFoundException, + InputsLoopFoundException, + MultipleExternalNodesException, + OnlyInternalNodesException, + OutputsLoopFoundException, + ExternalTarget, + ModuleTarget, + ConstantTarget, + FunctionTarget, + RemoteTarget, + StitchedModule, + StitchedModuleException, + CantResolveNodeDependenciesException, + StitchedModuleOutput, +) +from .passage import always_false_predicate, always_true_predicate, InputArgs diff --git a/modelopt/torch/_compress/sewing_kit/common.py b/modelopt/torch/_compress/sewing_kit/common.py new file mode 100644 index 000000000..5bc573232 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/common.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +logger = logging.getLogger("sewing_kit") +logger.setLevel(logging.WARN) diff --git a/modelopt/torch/_compress/sewing_kit/core.py b/modelopt/torch/_compress/sewing_kit/core.py new file mode 100644 index 000000000..550c1298c --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/core.py @@ -0,0 +1,881 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
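# A minimal, hypothetical usage sketch of the stitching API re-exported by the
# __init__.py above (Needle, ModuleTarget, ExternalTarget, InputArgs), based on
# the core.py implementation added below. The module/submodule names and shapes
# here are illustrative only and are not part of this patch.

import torch
import torch.nn as nn

from modelopt.torch._compress.sewing_kit import (
    ExternalTarget,
    InputArgs,
    ModuleTarget,
    Needle,
)

teacher = ModuleTarget(name="teacher", module=nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))
student = ModuleTarget(name="student", module=nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))
external = ExternalTarget()

needle = Needle()
# Route the caller-supplied "x" input into both wrapped modules; the empty
# input name "" targets the root module of each ModuleTarget.
needle.stitch(external.input("x"), teacher.input(""))
needle.stitch(external.input("x"), student.input(""))
# Capture the output of each module's first submodule ("0" in nn.Sequential)
# back out to the caller under descriptive names.
needle.stitch(teacher.output("0"), external.output("teacher_hidden"))
needle.stitch(student.output("0"), external.output("student_hidden"))

stitched = needle.knot()  # validates the stitch graph and returns a StitchedModule
result = stitched({"x": InputArgs(torch.randn(2, 8))}, {})
print(sorted(result.captured_outputs))  # ['student_hidden', 'teacher_hidden']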
+ +# mypy: ignore-errors +from __future__ import annotations +from abc import ABC +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union +from typing_extensions import override + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +import torch +import torch.distributed +import torch.nn as nn + +from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip +from .passage import ( + Passage, + InputArgs, + OutputValue, + Predicate, + always_false_predicate, + PassageInputAdapter, + PassageOutputAdapter, + PassageInputOverrides, + PassageOutputOverrides, +) + + +InputAdapter = Callable[[InputArgs], InputArgs] +OutputAdapter = Callable[..., OutputValue] + + +def default_input_adapter_fn(input_values: InputArgs) -> InputArgs: + return input_values + + +def default_output_adapter_fn(v: OutputValue) -> OutputValue: + return v + + +@dataclass +class IOReducer: + pass + + +def default_input_reducer_fn(acc: InputArgs, input_override: InputArgs, *args): + return acc + input_override + + +@dataclass +class InputReducer(IOReducer): + reducer_fn: Callable[[InputArgs, InputArgs, InputArgs, int, list[InputArgs]], InputArgs] = ( + default_input_reducer_fn + ) + + def __call__( + self, + acc: InputArgs, + input_override: InputArgs, + original_input: InputArgs, + index: int, + all_input_overrides: list[InputArgs], + ) -> InputArgs: + result = self.reducer_fn(acc, input_override, original_input, index, all_input_overrides) + return result + + @classmethod + def default(cls) -> InputReducer: + return InputReducer() + + +def default_output_reducer_fn(acc: OutputValue, input_override: OutputValue, *args): + return input_override + + +@dataclass +class OutputReducer(IOReducer): + reducer_fn: Callable[ + [OutputValue, OutputValue, Optional[OutputValue], int, list[OutputValue]], OutputValue + ] = default_output_reducer_fn + requires_original_output: bool = False + + def __call__( + self, + acc: OutputValue, + output_override: OutputValue, + original_output: Optional[OutputValue], + index: int, + all_output_overrides: list[OutputValue], + ) -> InputArgs: + result = self.reducer_fn(acc, output_override, original_output, index, all_output_overrides) + return result + + @classmethod + def default(cls) -> OutputReducer: + return OutputReducer() + + +class Singleton(type): + _instances = {} + + @override + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class Target: + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class TargetWithInput(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name="", input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedInputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def input( + self, + name: str, + adapter: InputAdapter = default_input_adapter_fn, + reducer: InputReducer = InputReducer.default(), + ) -> InputDescriptor: + result = InputDescriptor(self, input_name=name, input_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithOutput(Target): + 
@override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name="", output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class TargetWithNamedOutputs(Target): + @override + def __hash__(self) -> int: + return super().__hash__() + + def output( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass +class ExternalTarget(TargetWithNamedInputs, TargetWithNamedOutputs, metaclass=Singleton): + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ConstantTarget(TargetWithOutput): + name: str + value: Any + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class FunctionTarget(TargetWithInput, TargetWithOutput): + name: str + function: Callable[..., Any] + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class ModuleTarget(TargetWithNamedInputs, TargetWithNamedOutputs): + name: str + module: nn.Module + + @override + def __str__(self) -> str: + return f"ModuleTarget({self.name})" + + @override + def __repr__(self) -> str: + return str(self) + + @override + def __hash__(self) -> int: + return super().__hash__() + + +@dataclass +class RemoteTarget(Target): + peer_rank: Union[int, Sequence[int]] + process_group: Optional[torch.distributed.ProcessGroup] = None + blocking: bool = True + + @override + def __hash__(self) -> int: + return super().__hash__() + + def value( + self, + name: str, + adapter: OutputAdapter = default_output_adapter_fn, + reducer: OutputReducer = OutputReducer.default(), + ) -> OutputDescriptor: + result = OutputDescriptor(self, output_name=name, output_adapter=adapter, reducer=reducer) + return result + + +@dataclass(frozen=True, eq=True) +class RemoteDataDescriptor(ABC): + key: str + + +@dataclass(frozen=True, eq=True) +class RemoteTensorDataDescriptor(RemoteDataDescriptor): + device: Literal["cuda", "cpu"] + dtype: torch.dtype + shape: torch.Size + + +@dataclass(frozen=True, eq=True) +class RemotePythonDataDescriptor(RemoteDataDescriptor): + value: Any + + +@dataclass +class Node: + target: Target + stitches_to: list[StitchDescriptor] = field(default_factory=list) + stitches_from: list[StitchDescriptor] = field(default_factory=list) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class InputDescriptor: + target: Target + input_name: str = "" + input_adapter: InputAdapter = field(default=default_input_adapter_fn) + reducer: InputReducer = field(default_factory=InputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class OutputDescriptor: + target: Target + output_name: str = "" + output_adapter: OutputAdapter = field(default=default_output_adapter_fn) + reducer: OutputReducer = field(default_factory=OutputReducer.default) + + @override + def __hash__(self) -> int: + return id(self) + + +IODescriptor = Union[InputDescriptor, OutputDescriptor] + + +@dataclass +class StitchDescriptor: + source_descriptor: IODescriptor + destination_descriptor: IODescriptor + + @override + def __hash__(self) -> int: + return id(self) + + +@dataclass +class StitchedModuleOutput: 
+ captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + + +class StitchedModuleException(Exception): + pass + + +class CantResolveNodeDependenciesException(StitchedModuleException): + pass + + +class StitchedModule(nn.Module): + def __init__( + self, + nodes: dict[Target, Node], + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> None: + super().__init__() + self.nodes = nodes + self.ignore_extra_overrides = ignore_extra_overrides + external_nodes = [n for n in nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in nodes.values() if isinstance(n.target, RemoteTarget)] + assert len(external_nodes) <= 1 + assert len(remote_nodes) + len(external_nodes) > 0 + self.external_node = external_nodes[0] if len(external_nodes) > 0 else None + self.internal_nodes = [ + n for n in nodes.values() if not isinstance(n.target, ExternalTarget) + ] + self.values_from_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + self.values_to_node: dict[Node, dict[IODescriptor, Any]] = defaultdict(dict) + + self.node_passages: dict[Node, Passage] = { + node: Passage.create( + module=node.target.module, + inputs_to_capture=set( + s.source_descriptor.input_name + for s in node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ), + outputs_to_capture=set( + s.source_descriptor.output_name + for s in node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ), + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=getattr(node.target, "name", None), + ) + for node in self.internal_nodes + if isinstance(node.target, ModuleTarget) + } + + self.passage_modules = nn.ModuleDict( + { + f"node_{node_index}": self.node_passages[node] + for node_index, node in enumerate(nodes.values()) + if node in self.node_passages + } + ) + self.adapter_modules = nn.ModuleDict( + { + f"node_{node_index}__stitch_{stitch_index}__{descriptor_name}": adapter + for node_index, node in enumerate(nodes.values()) + for stitch_index, stitch in enumerate(node.stitches_from + node.stitches_to) + for descriptor_name, descriptor in ( + ("source", stitch.source_descriptor), + ("destination", stitch.destination_descriptor), + ) + for adapter in [ + descriptor.input_adapter + if isinstance(descriptor, InputDescriptor) + else descriptor.output_adapter + ] + if isinstance(adapter, nn.Module) + } + ) + + def create_input_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageInputOverrides: + input_descriptors_by_group = defaultdict[str, list[InputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, InputDescriptor): + input_descriptors_by_group[io_descriptor.input_name].append(io_descriptor) + + input_overrides = PassageInputOverrides() + for group, input_descriptors in input_descriptors_by_group.items(): + reducers = [d.reducer for d in input_descriptors] + + def create_reducer(input_descriptors=input_descriptors, reducers=reducers): + inputs = [values_to_node[d] for d in input_descriptors] + + def reducer_fn( + original_input: InputArgs, + module_name: Optional[str], + module: Optional[nn.Module], + ) -> InputArgs: + acc = InputArgs() + for i, (input_, reducer) in enumerate(zip(inputs, reducers)): + acc = reducer(acc, input_, original_input, i, inputs) + return acc + + return reducer_fn + + input_override = PassageInputAdapter(create_reducer()) + input_overrides[group] = 
input_override + + return input_overrides + + def create_output_overrides( + self, values_to_node: dict[IODescriptor, Any] + ) -> PassageOutputOverrides: + output_descriptors_by_group = defaultdict[str, list[OutputDescriptor]](list) + for io_descriptor in values_to_node.keys(): + if isinstance(io_descriptor, OutputDescriptor): + output_descriptors_by_group[io_descriptor.output_name].append(io_descriptor) + + output_overrides = PassageOutputOverrides() + for group, output_descriptors in output_descriptors_by_group.items(): + reducers = [d.reducer for d in output_descriptors] + requires_original_output = any(r.requires_original_output for r in reducers) + + def create_reducer(reducers=reducers): + outputs = [values_to_node[d] for d in output_descriptors] + + def reducer_fn( + original_output: Optional[OutputValue], + module_name: Optional[str], + module: Optional[nn.Module], + ) -> OutputValue: + acc = None + for i, (output, reducer) in enumerate(zip(outputs, reducers)): + acc = reducer(acc, output, original_output, i, outputs) + return acc + + return reducer_fn + + reducer_fn = create_reducer() + if requires_original_output: + output_override = PassageOutputAdapter(reducer_fn) + else: + output_override = reducer_fn(None, None, None) + + output_overrides[group] = output_override + + return output_overrides + + @override + def __call__( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + return super().__call__(input_overrides, output_overrides, *args, **kwargs) + + @override + @dynamo_skip + def forward( + self, + input_overrides: dict[str, Any], + output_overrides: dict[str, Any], + *args, + **kwargs, + ) -> StitchedModuleOutput: + input_overrides = {k: InputArgs.from_value(v) for k, v in input_overrides.items()} + + self.values_from_node.clear() + self.values_to_node.clear() + + unresolved_count: int = 0 + nodes_stack: list[Node] = ( + [] if self.external_node is None else [self.external_node] + ) + self.internal_nodes + while len(nodes_stack) > 0: + node = nodes_stack.pop(0) + values_from_node = self.values_from_node[node] + values_to_node = self.values_to_node[node] + + if isinstance(node.target, ExternalTarget): + assert self.external_node is not None + + if not self.ignore_extra_overrides: + input_override_names = set(input_overrides.keys()) + external_node_input_names = set( + s.source_descriptor.input_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, InputDescriptor) + ) + assert input_override_names == external_node_input_names + output_override_names = set(output_overrides.keys()) + external_node_output_names = set( + s.source_descriptor.output_name + for s in self.external_node.stitches_from + if isinstance(s.source_descriptor, OutputDescriptor) + ) + assert output_override_names == external_node_output_names + + for stitch in self.external_node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + orig_input_override = input_overrides[stitch.source_descriptor.input_name] + input_override = stitch.source_descriptor.input_adapter(orig_input_override) + values_from_node[stitch.source_descriptor] = input_override + elif isinstance(stitch.source_descriptor, OutputDescriptor): + orig_output_override = output_overrides[ + stitch.source_descriptor.output_name + ] + output_override = stitch.source_descriptor.output_adapter( + orig_output_override + ) + values_from_node[stitch.source_descriptor] = output_override + else: + raise RuntimeError("Shouldn't 
happen") + + else: + if len(values_to_node) < len(node.stitches_to): + nodes_stack.append(node) + unresolved_count += 1 + if unresolved_count >= len(nodes_stack): + raise CantResolveNodeDependenciesException( + "Can't resolve nodes dependencies" + ) + continue + + if isinstance(node.target, ConstantTarget): + assert len(values_to_node) == 0 + + output_value = node.target.value + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(output_value) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, FunctionTarget): + assert all( + isinstance(v, InputDescriptor) and v.input_name == "" + for v in values_to_node + ) + + function_input_overrides = self.create_input_overrides(values_to_node)[""] + + if isinstance(function_input_overrides, InputArgs): + input_args = function_input_overrides + else: + input_args = function_input_overrides(InputArgs(), None, None) + + function_output = node.target.function(*input_args.args, **input_args.kwargs) + + for stitch in node.stitches_from: + assert isinstance(stitch.source_descriptor, OutputDescriptor) + assert stitch.source_descriptor.output_name == "" + value = stitch.source_descriptor.output_adapter(function_output) + values_from_node[stitch.source_descriptor] = value + + elif isinstance(node.target, ModuleTarget): + passage = self.node_passages[node] + passage.input_overrides = self.create_input_overrides(values_to_node) + passage.output_overrides = self.create_output_overrides(values_to_node) + passage_output = passage(*args, **kwargs) + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, InputDescriptor): + captured_input = passage_output.captured_inputs[ + stitch.source_descriptor.input_name + ] + value = stitch.source_descriptor.input_adapter(captured_input) + values_from_node[stitch.source_descriptor] = value + elif isinstance(stitch.source_descriptor, OutputDescriptor): + captured_output = passage_output.captured_outputs[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(captured_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + + elif isinstance(node.target, RemoteTarget): + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_from_node + ) + assert all( + isinstance(v, OutputDescriptor) and v.output_name != "" + for v in values_to_node + ) + + process_group = node.target.process_group + peers = node.target.peer_rank + if not isinstance(peers, Sequence): + peers = [peers] + + if len(values_to_node) > 0: + items_to_send = list(self.create_output_overrides(values_to_node).items()) + + data_descriptors: list[RemoteDataDescriptor] = [] + tensors_to_send: list[torch.Tensor] = [] + + for key, value in items_to_send: + if isinstance(value, torch.Tensor): + if value.is_cuda: + tensor_device = "cuda" + elif value.is_cpu: + tensor_device = "cpu" + else: + raise RuntimeError( + f"Invalid tensor device to send to remote target: {value.device}" + ) + + data_descriptor = RemoteTensorDataDescriptor( + key=key, + device=tensor_device, + dtype=value.dtype, + shape=value.shape, + ) + tensors_to_send.append(value) + + else: + data_descriptor = RemotePythonDataDescriptor( + key=key, + value=value, + ) + + data_descriptors.append(data_descriptor) + + works: list[Optional[torch.distributed.Work]] = [] + for peer in peers: + if 
process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + peer_works = distributed_isend_obj(data_descriptors, dst=peer) + works.extend(peer_works) + + for tensor in tensors_to_send: + work = torch.distributed.isend(tensor, dst=peer) + works.append(work) + + if node.target.blocking: + for work in works: + if work is not None: + work.wait() + + pass + + if len(node.stitches_from) > 0: + assert len(peers) == 1, ( + f"Cannot use multiple peers when using RemoteTarget as a source ({peers=})" + ) + (peer,) = peers + + if process_group is not None: + peer = torch.distributed.get_global_rank(process_group, peer) + + data_descriptors = distributed_recv_obj(src=peer) + assert isinstance(data_descriptors, list) + + tensors_to_recv: list[torch.Tensor] = [] + received_values: dict[str, Any] = {} + for data_descriptor in data_descriptors: + if isinstance(data_descriptor, RemoteTensorDataDescriptor): + tensor = torch.empty( + data_descriptor.shape, + dtype=data_descriptor.dtype, + device=data_descriptor.device, + ) + tensors_to_recv.append(tensor) + received_values[data_descriptor.key] = tensor + elif isinstance(data_descriptor, RemotePythonDataDescriptor): + received_values[data_descriptor.key] = data_descriptor.value + else: + raise RuntimeError( + f"Received invalid data descriptor from remote peer: {data_descriptor}" + ) + + works: list[Optional[torch.distributed.Work]] = [] + for tensor in tensors_to_recv: + work = torch.distributed.irecv(tensor, src=peer) + works.append(work) + + for work in works: + if work is not None: + work.wait() + + for stitch in node.stitches_from: + if isinstance(stitch.source_descriptor, OutputDescriptor): + remote_output = received_values[ + stitch.source_descriptor.output_name + ] + value = stitch.source_descriptor.output_adapter(remote_output) + values_from_node[stitch.source_descriptor] = value + else: + raise RuntimeError("Shouldn't happen") + else: + raise RuntimeError("Shouldn't happen") + + for stitch in node.stitches_from: + dst_node = self.nodes[stitch.destination_descriptor.target] + value = values_from_node[stitch.source_descriptor] + + if isinstance(stitch.destination_descriptor, InputDescriptor): + value = stitch.destination_descriptor.input_adapter(value) + elif isinstance(stitch.destination_descriptor, OutputDescriptor): + value = stitch.destination_descriptor.output_adapter(value) + else: + raise RuntimeError("Shouldn't happen") + + self.values_to_node[dst_node][stitch.destination_descriptor] = value + + unresolved_count = 0 + + values_to_external_node = ( + {} if self.external_node is None else self.values_to_node[self.external_node] + ) + output = StitchedModuleOutput( + captured_inputs={ + k.input_name: v + for k, v in values_to_external_node.items() + if isinstance(k, InputDescriptor) + }, + captured_outputs={ + k.output_name: v + for k, v in values_to_external_node.items() + if isinstance(k, OutputDescriptor) + }, + ) + + self.values_from_node.clear() + self.values_to_node.clear() + + return output + + +class KnotException(Exception): + pass + + +class LoopFoundException(KnotException): + pass + + +class InputsLoopFoundException(LoopFoundException): + pass + + +class OutputsLoopFoundException(LoopFoundException): + pass + + +class MultipleExternalNodesException(KnotException): + pass + + +class OnlyInternalNodesException(KnotException): + pass + + +class Needle: + def __init__(self) -> None: + self.nodes = dict[Target, Node]() + + def get_node_for_target(self, target: Target) -> Node: + if target not in 
self.nodes: + node = Node(target=target) + self.nodes[target] = node + else: + node = self.nodes[target] + + return node + + def stitch(self, src: IODescriptor, dst: IODescriptor) -> Self: + descriptor = StitchDescriptor(source_descriptor=src, destination_descriptor=dst) + + src_node = self.get_node_for_target(descriptor.source_descriptor.target) + dst_node = self.get_node_for_target(descriptor.destination_descriptor.target) + + if descriptor not in src_node.stitches_from: + src_node.stitches_from.append(descriptor) + + if descriptor not in dst_node.stitches_to: + dst_node.stitches_to.append(descriptor) + + return self + + def _search_loops( + self, + node: Node, + expand_fn: Callable[[Node], Iterable[IODescriptor]], + traversed_nodes: Optional[set[Node]] = None, + ) -> bool: + if isinstance(node.target, ExternalTarget): + return False + + if traversed_nodes is None: + traversed_nodes = set() + + if node in traversed_nodes: + found_loop = True + else: + traversed_nodes = traversed_nodes | {node} + found_loop = False + descriptors = expand_fn(node) + for descriptor in descriptors: + stitch_node = self.get_node_for_target(descriptor.target) + found_loop |= self._search_loops(stitch_node, expand_fn, traversed_nodes) + + return found_loop + + def _validate_nodes(self): + # internal_nodes = [n for n in self.nodes.values() if not isinstance(n.target, (ExternalTarget, RemoteTarget))] + external_nodes = [n for n in self.nodes.values() if isinstance(n.target, ExternalTarget)] + remote_nodes = [n for n in self.nodes.values() if isinstance(n.target, RemoteTarget)] + + if len(external_nodes) + len(remote_nodes) == 0: + raise OnlyInternalNodesException(f"Has only internal nodes") + + if len(external_nodes) > 1: + raise MultipleExternalNodesException( + f"Expected no more than 1 external node, found {len(external_nodes)}" + ) + + for i, node in enumerate(self.nodes.values()): + found_inputs_loop = self._search_loops( + node, lambda n: [s.source_descriptor for s in n.stitches_to] + ) + if found_inputs_loop: + raise InputsLoopFoundException(f"Found a loop in inputs of node {i}: {node}") + + found_outputs_loop = self._search_loops( + node, lambda n: [s.destination_descriptor for s in n.stitches_from] + ) + if found_outputs_loop: + raise OutputsLoopFoundException(f"Found a loop in outputs of node {i}: {node}") + + def knot( + self, + capture_cache_outputs_predicate=always_false_predicate, + early_exit=True, + ignore_extra_overrides=False, + ) -> StitchedModule: + self._validate_nodes() + + module = StitchedModule( + nodes=self.nodes, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + ignore_extra_overrides=ignore_extra_overrides, + ) + + return module diff --git a/modelopt/torch/_compress/sewing_kit/passage/__init__.py b/modelopt/torch/_compress/sewing_kit/passage/__init__.py new file mode 100644 index 000000000..98dfc683b --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/passage/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .core import ( + Passage, + PassageOutput, + InputArgs, + OutputValue, + Predicate, + PassageInputAdapter, + PassageOutputAdapter, + PassageInputOverrides, + PassageOutputOverrides, + always_true_predicate, + always_false_predicate, +) diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py new file mode 100644 index 000000000..9c200d569 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/passage/core.py @@ -0,0 +1,462 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +from __future__ import annotations +import sys + +if sys.version_info[1] < 9: # if less than pytorch 3.9 + from typing import Sequence, Callable +else: + from collections.abc import Sequence, Callable + +from dataclasses import dataclass +from typing import Any, ContextManager, Iterable, Mapping, Optional, Union + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +from typing_extensions import override + +import torch.nn as nn +from ..utils import ( + ActivityContext, + has_fake_tensor, + fake_tensors, + is_submodule_of, + is_submodule_or_same, + real_tensors, + dynamo_skip, +) +from ..common import logger + + +@dataclass +class InputArgs: + args: list[Any] + kwargs: dict[str, Any] + + def __init__(self, *args, **kwargs): + self.args = list(args) + self.kwargs = dict(kwargs) + + def __add__(self, other: Any) -> InputArgs: + assert isinstance(other, InputArgs) + result = InputArgs(*self.args, *other.args, **{**self.kwargs, **other.kwargs}) + return result + + def drop_args(self, index: int | slice | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if index is None: + new_args.args.clear() + else: + del new_args.args[index] + + return new_args + + def drop_kwargs(self, keys: Sequence[str] | None = None) -> InputArgs: + new_args = InputArgs(*self.args, **self.kwargs) + if keys is None: + new_args.kwargs.clear() + else: + for key in keys: + new_args.kwargs.pop(key, None) + + return new_args + + @classmethod + def from_value(cls, v): + if isinstance(v, cls): + return v + elif isinstance(v, InputArgs): + return cls(*v.args, **v.kwargs) + elif isinstance(v, Sequence): + return cls(*v) + else: + return cls(v) + + +OutputValue = Any + + +@dataclass +class PassageInputAdapter: + adapter_fn: Callable[[InputArgs, Optional[str], Optional[nn.Module]], InputArgs] + + def __call__( + self, original_input: InputArgs, module_name: 
Optional[str], module: Optional[nn.Module] + ) -> InputArgs: + result = self.adapter_fn(original_input, module_name, module) + return result + + +@dataclass +class PassageOutputAdapter: + adapter_fn: Callable[[Any, Optional[str], Optional[nn.Module]], Any] + + def __call__( + self, original_output: Any, module_name: Optional[str], module: Optional[nn.Module] + ) -> Any: + result = self.adapter_fn(original_output, module_name, module) + return result + + +class PassageInputOverrides(dict[str, Union[PassageInputAdapter, InputArgs]]): + def __init__(self, input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}): + for k, v in input_overrides.items(): + self[k] = v + + # def __setitem__(self, key: str, value: InputAdapter | InputArgs) -> None: + # if isinstance(key, InputArgs): + # def adapter_fn(original_input: InputArgs) -> InputArgs: + # assert isinstance(value, InputArgs) + # return value + # self[key] = InputAdapter(adapter_fn) + # else: + # self[key] = value + + +class PassageOutputOverrides(dict[str, Union[PassageOutputAdapter, Any]]): + def __init__(self, output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}): + for k, v in output_overrides.items(): + self[k] = v + + +class NoActivePassageContextError(RuntimeError): + pass + + +class RequiredPassageOutputsCapturedSignal(Exception): + pass + + +@dataclass +class PassageOutput: + captured_inputs: dict[str, InputArgs] + captured_outputs: dict[str, Any] + captured_fake_outputs: dict[str, Any] + module_output: Any + + +Predicate = Callable[[str, nn.Module], bool] + + +def always_false_predicate(module_name: str, module: nn.Module) -> bool: + return False + + +def always_true_predicate(module_name: str, module: nn.Module) -> bool: + return True + + +class Passage(nn.Module): + create_fn_context = ActivityContext[None](max_depth=1) + active_passages_context = ActivityContext["Passage"](no_duplicates=True, reversed=True) + + def __init__( + self, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ): + super().__init__() + + if not self.create_fn_context.is_active(): + raise RuntimeError("Please use Passage.create(...) 
in order to create a new Passage") + + self.active_context_manager: Optional[ContextManager] = None + + self.name = name + self.module = module + self.module_to_name_mapping = {id(v): k for k, v in module.named_modules()} + self.inputs_to_capture = set(inputs_to_capture) + self.outputs_to_capture = set(outputs_to_capture) + self.input_overrides = input_overrides + self.output_overrides = output_overrides + self.outputs_cache = outputs_cache + self.capture_fake_outputs_predicate = capture_fake_outputs_predicate + self.capture_cache_outputs_predicate = capture_cache_outputs_predicate + self.early_exit = early_exit + + self.reset() + + @property + def input_overrides(self) -> PassageInputOverrides: + return self._input_overrides + + @input_overrides.setter + def input_overrides(self, value: Mapping[str, PassageInputAdapter | InputArgs]): + self._input_overrides = PassageInputOverrides(value) + + @property + def output_overrides(self) -> PassageOutputOverrides: + return self._output_overrides + + @output_overrides.setter + def output_overrides(self, value: Mapping[str, PassageOutputAdapter | Any]): + self._output_overrides = PassageOutputOverrides(value) + + def reset(self): + self.required_capture_count = ( + (len(self.inputs_to_capture) + len(self.outputs_to_capture)) + if self.early_exit + else None + ) + self.captured_outputs: dict[str, Any] = {} + self.captured_inputs: dict[str, InputArgs] = {} + self.captured_fake_outputs: dict[str, Any] = {} + + @classmethod + def module_name_relative_to_active_passage(cls, module: PatchedModule) -> str: + root_passage = Passage.active_passages_context.get_active() + assert root_passage is not None + module_name = root_passage.module_to_name_mapping[id(module)] + return module_name + + @classmethod + def create( + cls, + module: nn.Module, + *, + inputs_to_capture: Iterable[str] = [], + outputs_to_capture: Iterable[str] = [], + input_overrides: Mapping[str, PassageInputAdapter | InputArgs] = {}, + output_overrides: Mapping[str, PassageOutputAdapter | Any] = {}, + outputs_cache: dict[str, Any] = {}, + capture_fake_outputs_predicate: Predicate = always_false_predicate, + capture_cache_outputs_predicate: Predicate = always_false_predicate, + early_exit: bool = False, + name: Optional[str] = None, + ) -> Passage: + with cls.create_fn_context(None): + passage = cls( + module=module, + inputs_to_capture=inputs_to_capture, + outputs_to_capture=outputs_to_capture, + input_overrides=input_overrides, + output_overrides=output_overrides, + outputs_cache=outputs_cache, + capture_fake_outputs_predicate=capture_fake_outputs_predicate, + capture_cache_outputs_predicate=capture_cache_outputs_predicate, + early_exit=early_exit, + name=name, + ) + + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + patch_module(submodule_name, submodule) + + # register_passage_hooks(module, descriptor) + + return passage + + def is_active(self) -> bool: + result = self.active_context_manager is not None + return result + + def __enter__(self): + assert self.active_context_manager is None + self.active_context_manager = Passage.active_passages_context(self) + self.active_context_manager.__enter__() + self.module_to_name_mapping = {id(v): k for k, v in self.named_modules()} + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.active_context_manager is not None + self.active_context_manager.__exit__(exc_type, exc_val, exc_tb) + + def freeze(self): + self.eval() + self.requires_grad_(False) + + def unfreeze(self): + self.train() + 
self.requires_grad_(True) + + def run(self, *args, **kwargs) -> PassageOutput: + return self(*args, **kwargs) + + @override + def __call__(self, *args, **kwargs) -> PassageOutput: + return super().__call__(*args, **kwargs) + + @dynamo_skip + @override + def forward(self, *args, **kwargs) -> PassageOutput: + self.reset() + + with Passage.active_passages_context(self): + try: + module_output = self.module(*args, **kwargs) + except RequiredPassageOutputsCapturedSignal: + module_output = None + + output = PassageOutput( + captured_inputs=self.captured_inputs, + captured_outputs=self.captured_outputs, + captured_fake_outputs=self.captured_fake_outputs, + module_output=module_output, + ) + + self.reset() + + return output + + +class PatchedModule: ... + + +def patch_module(module_name_: str, module: nn.Module): + # orig_forward = module.forward + + if isinstance(module, PatchedModule): + # if module_name != Passage.module_name_relative_to_active_passage(module): + # logger.warn(f'Module "{module_name}" already patched for module "{Passage.module_name_relative_to_active_passage(module)}". Could lead to bugs.') + return + + orig_class = module.__class__ + + class PassageModuleWrapper(orig_class, PatchedModule): + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def can_be_skipped(_self: PassageModuleWrapper, depth: int) -> bool: + passages_beyond_depth = Passage.active_passages_context[depth:] + module_name = Passage.module_name_relative_to_active_passage(_self) + + results = [ + ( + module_name in passage.outputs_cache + and not any( + is_submodule_or_same(k, module_name) for k in passage.outputs_to_capture + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.input_overrides.items() + if v is not None + ) + and not any( + is_submodule_of(k, module_name) + for k, v in passage.output_overrides.items() + if v is not None + ) + ) + for passage in passages_beyond_depth + ] + + result = all(results) + + return result + + # Defined as a static method to avoid potential collision with original class methods + @staticmethod + @dynamo_skip + def run_passage(_self: PassageModuleWrapper, depth: int, args, kwargs): + if depth + 1 > len(Passage.active_passages_context): + output = super(PassageModuleWrapper, _self).__call__(*args, **kwargs) + return output + + module_name = Passage.module_name_relative_to_active_passage(_self) + passage = Passage.active_passages_context[depth] + + has_output_override = module_name in passage.output_overrides + output_override = passage.output_overrides.get(module_name) + + if has_output_override and not isinstance(output_override, PassageOutputAdapter): + output = output_override + else: + input_override = passage.input_overrides.get(module_name) + if input_override is not None: + original_input_args = InputArgs(*args, **kwargs) + + if isinstance(input_override, PassageInputAdapter): + new_input_args = input_override(original_input_args, module_name, module) + else: + new_input_args = input_override + + args, kwargs = new_input_args.args, new_input_args.kwargs + + if ( + output_override is None + and PassageModuleWrapper.can_be_skipped(_self, depth) + and (has_fake_tensor(args) or has_fake_tensor(kwargs)) + ): + cached_output = passage.outputs_cache[module_name] + return cached_output + + output = PassageModuleWrapper.run_passage( + _self=_self, + depth=depth + 1, + args=args, + kwargs=kwargs, + ) + + if isinstance(output_override, PassageOutputAdapter): + output = 
output_override(output, module_name, module) + + if passage.capture_fake_outputs_predicate(module_name, module): + fake_output = fake_tensors(output) + passage.captured_fake_outputs[module_name] = fake_output + + if not module_name in passage.outputs_cache and passage.capture_cache_outputs_predicate( + module_name, module + ): + fake_output = fake_tensors(output) + passage.outputs_cache[module_name] = fake_output + + if module_name in passage.inputs_to_capture: + real_args, real_kwargs = real_tensors(args), real_tensors(kwargs) + passage.captured_inputs[module_name] = InputArgs(*real_args, **real_kwargs) + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if module_name in passage.outputs_to_capture: + real_output = real_tensors(output) + output_value = real_output + passage.captured_outputs[module_name] = output_value + + if passage.required_capture_count is not None: + passage.required_capture_count -= 1 + + if passage.required_capture_count == 0: + raise RequiredPassageOutputsCapturedSignal() + + return output + + @dynamo_skip + @override + def __call__(self, *args, **kwargs): + output = self.run_passage( + _self=self, + depth=0, + args=args, + kwargs=kwargs, + ) + return output + + # module.forward = forward + PassageModuleWrapper.__name__ = f"ModuleWrapper({module.__class__.__name__})" + module.__class__ = PassageModuleWrapper diff --git a/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py b/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/modelopt/torch/_compress/sewing_kit/utils.py b/modelopt/torch/_compress/sewing_kit/utils.py new file mode 100644 index 000000000..ebe90b2a4 --- /dev/null +++ b/modelopt/torch/_compress/sewing_kit/utils.py @@ -0,0 +1,506 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
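# A small, hypothetical sketch of the Passage capture/override mechanism
# implemented in passage/core.py above. The toy module and the submodule names
# used here are illustrative and not part of this patch.

import torch
import torch.nn as nn

from modelopt.torch._compress.sewing_kit.passage import InputArgs, Passage


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.layer2(self.layer1(x))


# Passage.create patches the wrapped module tree so that named submodules can
# be captured or overridden without touching the module code itself.
passage = Passage.create(
    module=Toy(),
    outputs_to_capture={"layer1"},                              # capture this submodule's output
    input_overrides={"layer2": InputArgs(torch.zeros(2, 8))},   # replace layer2's input
    early_exit=False,                                           # run the full forward pass
)

out = passage(torch.randn(2, 8))
print(out.captured_outputs["layer1"].shape)  # torch.Size([2, 8])
print(out.module_output.shape)               # torch.Size([2, 8]); layer2 saw the zero override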
+# mypy: ignore-errors +from __future__ import annotations + +import inspect +from collections.abc import Sequence, Mapping +from contextlib import contextmanager +from typing import ( + Any, + Callable, + ContextManager, + Generic, + Iterable, + Literal, + Optional, + Protocol, + TypeVar, + cast, + overload, +) +from typing_extensions import override +import torch +import torch.distributed +import torch._dynamo +import torch._C +from torch import Tensor +import torch.utils._pytree as pytree +import torch.nn as nn +import torch.nn.functional as F +from torch._subclasses import FakeTensor, FakeTensorMode + + +Fn = TypeVar("Fn", bound=Callable) + + +class DynamoSkip(Protocol): + @overload + def __call__(self, fn: None = None) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn) -> Fn: ... + + +class DynamoDisable(Protocol): + @overload + def __call__(self, fn: None = None, disable: bool = False) -> Callable[[Fn], Fn]: ... + @overload + def __call__(self, fn: Fn, disable: bool = False) -> Fn: ... + + +try: + dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.decorators).skip + dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.decorators).disable +except: + dynamo_skip: DynamoSkip = cast(Any, torch._dynamo.eval_frame).skip + dynamo_disable: DynamoDisable = cast(Any, torch._dynamo.eval_frame).disable + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class ModuleRef(Generic[TModule]): + def __init__(self, module: TModule): + self.module = module + + +Reduction = Literal["none", "mean", "sum"] + + +def normalized_mse_loss( + input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6 +): + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def mse_loss(input: Tensor, target: Tensor, reduction: Reduction = "mean", epsilon: float = 1e-6): + loss = F.mse_loss(input, target, reduction=reduction) + return loss + + +class NormalizedMSELoss(nn.modules.loss._Loss): + __constants__ = ["reduction", "epsilon"] + + def __init__(self, reduction: Reduction = "mean", epsilon: float = 1e-6) -> None: + super().__init__(None, None, reduction) + self.epsilon = epsilon + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + loss = normalized_mse_loss( + input, + target, + cast(Reduction, self.reduction), + self.epsilon, + ) + return loss + + +def vectorwise_normalized_mse_loss(input: Tensor, target: Tensor, epsilon: float = 1e-6): + """ + Like normalized_mse_loss, but the input is treated as a multi-dimensional batch of vectors. + Normalization is done on each vector separately (the last dim), then results are averaged. + """ + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: Tensor, target: Tensor, epsilon: float = 1e-6, batch_dims: Sequence[int] = (0,) +): + """ + Like normalized_mse_loss, but the input is treated as a batch of tensors. + Normalization is done on the non-batch dims, then results are averaged. 
+ """ + norm_dims = list(set(range(input.ndim)) - set(batch_dims)) + norm_of_target_vectors = F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction="none" + ).mean(dim=norm_dims) + vectorwise_mse = F.mse_loss(input, target, reduction="none").mean(dim=norm_dims) + normalized_vectorwise_mse = vectorwise_mse / norm_of_target_vectors + loss = normalized_vectorwise_mse.mean() + return loss + + +class ActivityContextMaxDepthException(Exception): + pass + + +class ActivityContextDuplicateException(Exception): + pass + + +T = TypeVar("T") + + +class ActivityContext(Generic[T]): + def __init__(self, max_depth: Optional[int] = None, no_duplicates=False, reversed=False): + self.activity_stack: list[T] = [] + self.max_depth = max_depth + self.no_duplicates = no_duplicates + self.reversed = reversed + self.head_index = 0 if self.reversed else -1 + + def __contains__(self, value: T) -> bool: + result = value in self.activity_stack + return result + + def __call__(self, value: T) -> ContextManager: + @contextmanager + def fn(): + try: + if self.no_duplicates and value in self.activity_stack: + raise ActivityContextDuplicateException( + f"Activity stack cannot have a duplicate of item {value}" + ) + + self.activity_stack.insert(self.head_index, value) + + if self.max_depth is not None and len(self) > self.max_depth: + raise ActivityContextMaxDepthException( + f"Activity stack exceeds max depth of {self.max_depth}" + ) + + yield + finally: + assert self.is_active() + self.activity_stack.pop(self.head_index) + + return fn() + + def __len__(self) -> int: + result = len(self.activity_stack) + return result + + @overload + def __getitem__(self, key: int) -> T: ... + @overload + def __getitem__(self, key: slice) -> Sequence[T]: ... + def __getitem__(self, key: int | slice) -> T | Sequence[T]: + result = self.activity_stack[key] + return result + + def is_active(self) -> bool: + result = len(self) > 0 + return result + + def get_active(self) -> Optional[T]: + if self.is_active: + return self.activity_stack[-1] + else: + return None + + +def is_submodule_of(module_name: str, other_module_name: str) -> bool: + result = module_name.startswith(f"{other_module_name}.") or ( + module_name != "" and other_module_name == "" + ) + return result + + +def is_submodule_or_same(module_name: str, other_module_name: str) -> bool: + result = module_name == other_module_name or is_submodule_of(module_name, other_module_name) + return result + + +def reduce_losses(losses: Iterable[Tensor]) -> Tensor: + total_loss = None + for loss in losses: + if total_loss is None: + total_loss = loss + else: + total_loss += loss + + if total_loss is None: + return torch.Tensor(torch.nan) + + return total_loss + + +fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, + # allow_fallback_kernels=False, +) + + +@overload +def fake_tensor(t: Tensor, *, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... + + +@overload +def fake_tensor( + size: Sequence[int] | torch.Size, *, dtype: Optional[torch.dtype] = None, use_meta=False +) -> Tensor: ... + + +@overload +def fake_tensor(*args: int, dtype: Optional[torch.dtype] = None, use_meta=False) -> Tensor: ... 
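# A hypothetical numerical sketch of the normalized-MSE helpers defined earlier
# in this file; the import path simply follows the file location added by this
# patch and the tensor shapes are illustrative.

import torch

from modelopt.torch._compress.sewing_kit.utils import (
    normalized_mse_loss,
    vectorwise_normalized_mse_loss,
)

pred = torch.randn(4, 16, 32)    # (batch, tokens, hidden)
target = torch.randn(4, 16, 32)

# Scalar loss: MSE(pred, target) divided by MSE(target, ~0), i.e. the error
# relative to the mean squared magnitude of the target activations.
print(normalized_mse_loss(pred, target))

# Same idea, but each (batch, token) vector is normalized by its own target
# magnitude over the hidden dim before the results are averaged.
print(vectorwise_normalized_mse_loss(pred, target))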
+ + +class MyFakeTensor(Tensor): + @dynamo_disable + def __init__(self, *args, **kwargs): + super().__init__() + self._t: FakeTensor + + @override + @dynamo_disable + def __repr__(self, *, tensor_contents=None): + return f"MyFakeTensor(shape={list(self._t.shape)}, dtype={self._t.dtype}, device={self._t.device})" + + @classmethod + @override + @dynamo_disable + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + args, kwargs = pytree.tree_map_only(MyFakeTensor, lambda t: t._t, (args, kwargs)) + + types = pytree.tree_map_only(type(MyFakeTensor), lambda t: FakeTensor, types) + + out = func(*args, **kwargs) + + out = pytree.tree_map_only(Tensor, lambda t: MyFakeTensor.create(t), out) + + return out + + __torch_function__ = torch._C._disabled_torch_function_impl + + # @dynamo_disable + # def __getattribute__(self, attr: str): + # if attr in {'_t', 'device', '__repr__', '__torch_function__', '__class__'}: + # return object.__getattribute__(self, attr) + + # result = getattr(self._t, attr) + + # result = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), result + # ) + # print('__getattribute__', 'attr', attr, 'ret', result) + + # return result + + @property + @dynamo_disable + def device(self): + return self._t.device + + # @property + # @dynamo_disable + # def shape(self): + # return self._t.shape + + # @dynamo_disable + # def size(self): + # return self._t.size() + + # @classmethod + # @dynamo_disable + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # if kwargs is None: + # kwargs = {} + + # args, kwargs = pytree.tree_map_only( + # MyFakeTensor, lambda t: t._t, (args, kwargs) + # ) + + # ret = func(*args, **kwargs) + + # ret = pytree.tree_map_only( + # Tensor, lambda t: MyFakeTensor.create(t), ret + # ) + # print('__torch_function__', 'func', func, 'ret', ret) + + # return ret + + @staticmethod + @dynamo_disable + def __new__(cls, elem, device) -> MyFakeTensor: + self = torch.Tensor._make_subclass( + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, + ) + return cast(MyFakeTensor, self) + + @classmethod + @dynamo_disable + def create(cls, data: Tensor) -> MyFakeTensor: + if isinstance(data, MyFakeTensor): + return data + + if isinstance(data, FakeTensor): + t = data + else: + t = FakeTensor.from_tensor(data, fake_mode=fake_mode) + + # my_fake_tensor = MyFakeTensor(torch.empty(t.shape, dtype=t.dtype, device='meta')) + my_fake_tensor = MyFakeTensor( + torch.empty(t.shape, dtype=t.dtype, device="meta"), + t.device, + ) + my_fake_tensor._t = t + + return my_fake_tensor + + +@dynamo_disable +def fake_tensor(*args, **kwargs) -> Tensor: + dtype: Optional[torch.dtype] = kwargs.get("dtype") + use_meta = kwargs.get("use_meta", False) + + if len(args) == 1 and isinstance(args[0], Tensor): + if use_meta: + fake_tensor = torch.empty(args[0].size(), dtype=dtype or args[0].dtype, device="meta") + else: + fake_tensor = MyFakeTensor.create(args[0]) + else: + fake_tensor = torch.empty(*args, dtype=dtype, device="meta") + if not use_meta: + fake_tensor = MyFakeTensor.create(fake_tensor) + + return fake_tensor + + +@dynamo_skip +def fake_tensor_like(t: Tensor, use_meta=False) -> Tensor: + return fake_tensor(t, use_meta=use_meta) + + +T = TypeVar("T") + + +@dynamo_skip +def fake_tensors(value: T, use_meta=False) -> T: + result = pytree.tree_map_only(Tensor, lambda t: fake_tensor_like(t, use_meta), value) + return result + # if isinstance(value, Mapping): + # return cast(Any, 
value.__class__)({k: fake_tensors(v, use_meta) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([fake_tensors(v, use_meta) for v in value]) + # if isinstance(value, Tensor): + # return fake_tensor_like(value, use_meta) + # return value + + +@dynamo_skip +def real_tensors(value: Any) -> Any: + result = pytree.tree_map_only(Tensor, lambda t: None if is_fake_tensor(t) else t, value) + return result + # if isinstance(value, Mapping): + # return cast(Any, value.__class__)({k: real_tensors(v) for k, v in value.items()}) + # if isinstance(value, Sequence): + # return cast(Any, value.__class__)([real_tensors(v) for v in value]) + # if is_fake_tensor(value): + # return None + # return value + + +@dynamo_skip +def is_fake_tensor(t: Any) -> bool: + return isinstance(t, (MyFakeTensor, FakeTensor)) or (isinstance(t, Tensor) and t.is_meta) + + +@dynamo_skip +def has_fake_tensor(v: Any) -> bool: + result = pytree.tree_any(is_fake_tensor, v) + return result + + +@dynamo_skip +def is_real_tensor(t: Any) -> bool: + return isinstance(t, Tensor) and not t.is_meta and not isinstance(t, FakeTensor) + + +@dynamo_skip +def get_parent_module_name(module_name: str): + if "." not in module_name: + return "" + else: + return module_name.rsplit(".", 1)[0] + + +@dynamo_skip +def get_parent_module_names(module_name: str): + parent_module_names = set[str]() + + while len(module_name) > 0: + module_name = get_parent_module_name(module_name) + parent_module_names.add(module_name) + + return parent_module_names + + +def distributed_isend_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> list[Optional[torch.distributed.Work]]: + obj_tensor, obj_size_tensor = torch.distributed.distributed_c10d._object_to_tensor( + obj, device="cpu", **_get_group_kwarg_if_necessary() + ) + works: list[Optional[torch.distributed.Work]] = [ + torch.distributed.isend(obj_size_tensor, dst, group), + torch.distributed.isend(obj_tensor, dst, group), + ] + # p2p_ops = [ + # torch.distributed.P2POp(torch.distributed.isend, obj_size_tensor, dst, group), + # torch.distributed.P2POp(torch.distributed.isend, obj_tensor, dst, group), + # ] + + # works = torch.distributed.batch_isend_irecv(p2p_ops) + + return works + + +def distributed_send_obj( + obj: Any, + dst: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, +): + works = distributed_isend_obj(obj=obj, dst=dst, group=group) + for work in works: + if work is not None: + work.wait() + + +def distributed_recv_obj( + src: Optional[int] = None, + group: Optional[torch.distributed.ProcessGroup] = None, +) -> Any: + obj_size_tensor = torch.LongTensor(1, device="cpu") + torch.distributed.recv(obj_size_tensor, src=src, group=group) + obj_size = int(obj_size_tensor.item()) + + obj_tensor = torch.ByteTensor(obj_size, device="cpu") + torch.distributed.recv(obj_tensor, src=src, group=group) + + obj = torch.distributed.distributed_c10d._tensor_to_object( + obj_tensor, obj_size, **_get_group_kwarg_if_necessary() + ) + + return obj + + +def _get_group_kwarg_if_necessary() -> dict: + """For newer versions of torch""" + arg_names = inspect.signature( + torch.distributed.distributed_c10d._object_to_tensor + ).parameters.keys() + return dict(group=None) if "group" in arg_names else dict() From a7a4adce982194d92a01f543b1d9acf777fa1ff0 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 09:57:39 +0100 Subject: [PATCH 35/49] add sewing_kit Signed-off-by: Daniel Korzekwa --- 
modelopt/torch/_compress/tools/kd_model.py | 343 ++++++++++++++++++ .../utils/validate_runtime_pipeline.py | 19 +- modelopt/torch/_compress/utils/validation.py | 2 +- 3 files changed, 359 insertions(+), 5 deletions(-) create mode 100644 modelopt/torch/_compress/tools/kd_model.py diff --git a/modelopt/torch/_compress/tools/kd_model.py b/modelopt/torch/_compress/tools/kd_model.py new file mode 100644 index 000000000..cbff6aab3 --- /dev/null +++ b/modelopt/torch/_compress/tools/kd_model.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +from abc import ABCMeta, abstractmethod +from typing import List, Callable, Literal, Tuple, Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class Block(nn.Module): + def __init__(self, *args, **kwargs): + raise NotImplementedError("This class is deprecated. Deci models are now hf models.") + + +class DummyBlock(nn.Module): + def __init__(self, *args, **kwargs): + raise NotImplementedError("This class is deprecated. Deci models are now hf models.") + + +RoPECache = torch.Tensor + + +def normalized_mse_loss( + input: Tensor, + target: Tensor, + reduction: Literal["none", "mean", "sum"] = "mean", + epsilon: float = 1e-6, +) -> Tensor: + loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction=reduction + ) + return loss + + +def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor: + # inputs are of shape (B,T,H) + batch_size = input.size(0) + input = input.view(batch_size, -1) + target = target.view(batch_size, -1) + target_tensor = input.new(input.size(0)).fill_(1) + loss = F.cosine_embedding_loss( + input1=input, input2=target, target=target_tensor, reduction="none" + ) + return loss + + +def cross_entropy_probs_batched(logits_input: Tensor, logits_target: Tensor) -> Tensor: + return F.cross_entropy( + logits_input.transpose(1, 2), logits_target.softmax(-1).transpose(1, 2), reduction="none" + ).mean(-1) + + +def kl_div_logits_batched(logits_input: Tensor, logits_target: Tensor) -> Tensor: + return ( + F.kl_div( + logits_input.log_softmax(-1), + logits_target.log_softmax(-1), + reduction="none", + log_target=True, + ) + .sum(-1) + .mean(-1) + ) + + +def kl_div_single_sample(logits_input: Tensor, logits_target: Tensor) -> Tensor: + return F.kl_div( + logits_input.log_softmax(-1), + logits_target.log_softmax(-1), + reduction="batchmean", + log_target=True, + ) + + +def kl_div_logits_batched_mem_efficient(logits_input: Tensor, logits_target: Tensor) -> Tensor: + batch_size = logits_input.shape[0] + kl_div_per_sample = [ + kl_div_single_sample(logits_input[i], logits_target[i]) for i in range(batch_size) + ] + return torch.stack(kl_div_per_sample) + + +kl_div = kl_div_logits_batched_mem_efficient + + +def mse_loss( + x_input: torch.Tensor, + x_target: 
torch.Tensor, +) -> Tensor: + return torch.stack( + [F.mse_loss(x_input[i_sample], x_target[i_sample]) for i_sample in range(x_input.shape[0])] + ) + + +def reverse_kl_div(logits_input: Tensor, logits_target: Tensor) -> Tensor: + return kl_div_logits_batched_mem_efficient(logits_target, logits_input) + + +def tv_dist(logits_input: Tensor, logits_target: Tensor) -> Tensor: + """ + Total Variation Distance: L1-loss between probabilities. + vocab dimension is summed, sequence dimension is averaged. + """ + batch_size, seq_len, vocab_size = logits_input.shape + tv_dist_per_sample = [ + F.l1_loss(logits_input[i].softmax(-1), logits_target[i].softmax(-1), reduction="sum") + / seq_len + for i in range(batch_size) + ] + return torch.stack(tv_dist_per_sample) + + +def js_div(logits_input: Tensor, logits_target: Tensor) -> Tensor: + """ + Jensen-Shannon Divergence for a single sample. + logits: [tokens, vocab] + target_probs: [tokens, vocab] + """ + batch_size = logits_input.shape[0] + _js_div = [] + for i in range(batch_size): + input_probs = logits_input[i].softmax(-1) + target_probs = logits_target[i].softmax(-1) + mixture_probs = (input_probs + target_probs) * 0.5 + mixture_logprobs = mixture_probs.log().clip(min=-20) + pred_kl_div = kl_div_single_sample(mixture_logprobs, input_probs) + target_kl_div = kl_div_single_sample(mixture_logprobs, target_probs) + js_div_i = 0.5 * (pred_kl_div + target_kl_div) + _js_div.append(js_div_i) + return torch.stack(_js_div) + + +LOGITS_LOSS_NAME_TO_FUNC = {name: func for name, func in globals().items() if callable(func)} + + +class KDLossWeigher(metaclass=ABCMeta): + @abstractmethod + def __call__( + self, + lm_loss: Tensor, + kd_block_loss: Tensor, + kd_logits_loss: Tensor, + ) -> Tensor: + raise NotImplementedError() + + +class StaticKDLossWeigher(KDLossWeigher): + def __init__( + self, + lm_weight: float, + kd_block_weight: float, + kd_logits_weight: float, + ): + self.lm_weight = lm_weight + self.kd_block_weight = kd_block_weight + self.kd_logits_weight = kd_logits_weight + + def __call__( + self, + lm_loss: Tensor, + kd_block_loss: Tensor, + kd_logits_loss: Tensor, + ) -> Tuple: + lm_loss = self.lm_weight * lm_loss + kd_block_loss = self.kd_block_weight * kd_block_loss + kd_logits_loss = self.kd_logits_weight * kd_logits_loss + + loss = lm_loss + kd_block_loss + kd_logits_loss + return loss, lm_loss, kd_block_loss, kd_logits_loss + + +class KDModel(nn.Module): + def __init__( + self, + student_model, + teacher_model, + block_loss_func: Callable, + logits_loss_func: Callable, + kd_loss_weigher: StaticKDLossWeigher, + teacher_requires_rope: bool = False, + ): + super().__init__() + assert not student_model.abs_positional + student_uses_rope = student_model.config.position_embedding_type in ["rope", "rope_llama4"] + teacher_uses_rope = teacher_model.config.position_embedding_type in ["rope", "rope_llama4"] + assert (student_uses_rope and teacher_uses_rope) or ( + not student_uses_rope and not teacher_uses_rope + ), "We do not support mixed rope usage" + self.use_rope = student_uses_rope + + self.logits_loss_func = logits_loss_func + self.block_loss_func = block_loss_func + # teacher_model.eval() + # teacher_model.requires_grad_(False) + self.student_wte, self.teacher_wte = ( + student_model.transformer.wte, + teacher_model.transformer.wte, + ) + self.blocks = nn.ModuleList() + student_blocks = student_model.transformer.h + teacher_blocks = teacher_model.transformer.h + for i in range(max(len(teacher_blocks), len(student_blocks))): + student_block = 
student_blocks[i] if i < len(student_blocks) else DummyBlock() + teacher_block = teacher_blocks[i] if i < len(teacher_blocks) else DummyBlock() + combined_block = KDBlock( + student_block, teacher_block, teacher_requires_rope=teacher_requires_rope + ) + self.blocks.append(combined_block) + self.student_ln_f, self.teacher_ln_f = ( + student_model.transformer.ln_f, + teacher_model.transformer.ln_f, + ) + self.student_lm_head, self.teacher_lm_head = student_model.lm_head, teacher_model.lm_head + self.block_size = student_model.block_size + + for teacher_module in (self.teacher_wte, self.teacher_ln_f, self.teacher_lm_head): + teacher_module.eval() + teacher_module.requires_grad_(False) + self.kd_loss_weigher = kd_loss_weigher + + def forward( + self, + idx: Tensor, + rope_cache: Optional[RoPECache] = None, + max_seq_length: Optional[int] = None, + varlen: bool = False, + concat_token_id: Optional[int] = None, + input_pos: Optional[int] = None, + kv_caches: Optional[List[torch.Tensor]] = None, + is_decode: bool = False, + ) -> tuple[Tensor, Tensor, Tensor]: + B, T = idx.size() + + block_size = self.block_size + if max_seq_length is None: + max_seq_length = block_size + if varlen: + assert B == 1, "Varlen can be used only with batch_size==1" + cu_seqlens = self.prepare_cu_seqlens(idx[0], block_size, concat_token_id) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + assert max_seqlen <= max_seq_length, ( + f"Cannot forward sequence of length {max_seqlen}, max seq length is only {max_seq_length}" + ) + assert max_seqlen <= block_size, ( + f"Cannot forward sequence of length {max_seqlen}, block size is only {block_size}" + ) + else: + cu_seqlens = None + max_seqlen = None + assert T <= max_seq_length, ( + f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + ) + assert T <= block_size, ( + f"Cannot forward sequence of length {T}, block size is only {block_size}" + ) + assert max_seq_length <= block_size, ( + f"Cannot attend to {max_seq_length}, block size is only {block_size}" + ) + + # forward the model itself + rope = None + if self.use_rope: + if rope_cache is None: + if self.rope_cache is None: + self.rope_cache = self.build_rope_cache( + dtype=torch.float32, device=torch.device("cpu") + ) + if idx.device != torch.device("meta"): + self.rope_cache = self.rope_cache.to(idx.device) + rope_cache = self.rope_cache + if input_pos is None: + rope = rope_cache[:T] + else: + rope = rope_cache[input_pos - T : input_pos] + + x_student, x_teacher = self.student_wte(idx), self.teacher_wte(idx) + + kd_block_loss = torch.zeros(B, device=idx.device, dtype=torch.float32) + for i_block, block in enumerate(self.blocks): + x_student, x_teacher = block(x_student, x_teacher, rope) + if self.kd_loss_weigher.kd_block_weight > 0: + curr_block_loss = self.block_loss_func(x_student, x_teacher) + kd_block_loss = kd_block_loss + curr_block_loss.float() / len(self.blocks) + + x_student, x_teacher = self.student_ln_f(x_student), self.teacher_ln_f(x_teacher) + logits_student, logits_teacher = ( + self.student_lm_head(x_student), + self.teacher_lm_head(x_teacher), + ) # (b, t, vocab_size) + if self.kd_loss_weigher.kd_logits_weight > 0: + kd_logits_loss = self.logits_loss_func(logits_student, logits_teacher) + else: + kd_logits_loss = torch.zeros(B, device=idx.device, dtype=torch.float32) + + return logits_student, logits_teacher, kd_block_loss, kd_logits_loss + + def train(self, mode: bool = True): + self.student_wte.train(mode) + self.student_ln_f.train(mode) + 
self.student_lm_head.train(mode) + for block in self.blocks: + block.student_block.train(mode) + return self + + +class KDBlock(nn.Module): + def __init__(self, student_block, teacher_block, teacher_requires_rope): + super().__init__() + self.student_block = student_block + self.teacher_block = teacher_block + self.teacher_requires_rope = teacher_requires_rope + + self.teacher_block.eval() + self.teacher_block.requires_grad_(False) + + def forward( + self, x_student: Tensor, x_teacher: Tensor, rope: Optional[RoPECache] + ) -> tuple[Tensor, Tensor]: + x_student = self.forward_block(x_student, self.student_block, rope=rope) + x_teacher = self.forward_block( + x_teacher, self.teacher_block, rope=rope if self.teacher_requires_rope else None + ) + + return x_student, x_teacher + + def forward_block(self, x: Tensor, block: Block, rope: Optional[RoPECache] = None) -> Tensor: + x = block(x=x, rope=rope, input_pos=None, kv_cache=None, is_decode=False) + return x diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index 3c735b431..e4c5c6bfd 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -28,13 +28,24 @@ LMHead, ) from modelopt.torch._compress.tools.runtime import IRuntime -from sewing_kit import ExternalTarget, InputArgs, ModuleTarget, Needle, RemoteTarget, StitchedModule -from sewing_kit.core import InputReducer -from sewing_kit.utils import distributed_recv_obj, distributed_send_obj, fake_tensor +from modelopt.torch._compress.sewing_kit import ( + ExternalTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, +) +from modelopt.torch._compress.sewing_kit.core import InputReducer +from modelopt.torch._compress.sewing_kit.utils import ( + distributed_recv_obj, + distributed_send_obj, + fake_tensor, +) from torch.utils.data import DataLoader from tqdm import tqdm from modelopt.torch._compress.tools.sharded_checkpoint_utils import DummyBlock -from utils.validation import _organize_outputs, calculate_batch_outputs +from modelopt.torch._compress.utils.validation import _organize_outputs, calculate_batch_outputs @torch.no_grad() diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index 256c10f6d..b1ec4b4fd 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -24,7 +24,7 @@ import torch.nn.functional as F import wandb from accelerate import Accelerator -from puzzle_tools import kd_model +from modelopt.torch._compress.tools import kd_model from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm From ad84c2625bb8d5fdb5285687243e0ce6d71b066f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 10:15:54 +0100 Subject: [PATCH 36/49] fix imports Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index b1ec4b4fd..8434f95cf 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -30,7 +30,7 @@ from tqdm import tqdm from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper from typing_extensions import Self -from utils.data.dataloaders import create_padded_tensor +from 
modelopt.torch._compress.utils.data.dataloaders import create_padded_tensor @torch.no_grad() From 3d7e8a2e836b60fd672960a59fb3a425c9ec0d8d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 6 Nov 2025 10:16:19 +0100 Subject: [PATCH 37/49] fix imports Signed-off-by: Daniel Korzekwa --- .../torch/_compress/tools/validate_model.py | 2 +- .../_compress/utils/checkpoint_manager.py | 272 ++++++++++++++++++ .../configs/validate_model_defaults.yaml | 2 +- 3 files changed, 274 insertions(+), 2 deletions(-) create mode 100644 modelopt/torch/_compress/utils/checkpoint_manager.py diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index 53d8ba176..51e0e7e0a 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -242,7 +242,7 @@ def validate_model( ) # Create checkpoint manager with hooks - from utils.checkpoint_manager import ScoringCheckpointManager + from modelopt.torch._compress.utils.checkpoint_manager import ScoringCheckpointManager mprint( f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}" diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py new file mode 100644 index 000000000..7bc3e73ef --- /dev/null +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Checkpoint manager for activation hook scoring with periodic saves and resume support. +""" + +import json +import time +from pathlib import Path +from typing import Dict, Any, Optional +from modelopt.torch._compress.tools.logger import mprint, aprint + + +class ScoringCheckpointManager: + """Manages checkpointing for activation hook scoring with periodic saves.""" + + def __init__( + self, checkpoint_dir: str, runtime, activation_hooks=None, checkpoint_interval: int = 100 + ): + """ + Initialize checkpoint manager. 
+ + Args: + checkpoint_dir: Directory to save checkpoints + runtime: Runtime object for distributed processing + activation_hooks: Dictionary of activation hooks to manage + checkpoint_interval: Save checkpoint every N batches + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.runtime = runtime + self.activation_hooks = activation_hooks + self.checkpoint_interval = checkpoint_interval + self.rank = runtime.global_rank if runtime is not None else 0 + self.is_main_process = runtime is None or runtime.is_main_process + + # Debug: Log checkpoint manager initialization + hook_count = len(activation_hooks) if activation_hooks else 0 + aprint( + f"[Rank {self.rank}] Checkpoint manager initialized: {hook_count} hooks, dir: {checkpoint_dir}" + ) + + # Checkpoint files + self.progress_file = self.checkpoint_dir / "scoring_progress.json" + self.hook_states_file = self.checkpoint_dir / f"hook_states_rank_{self.rank}.pth" + + # Progress tracking + self.current_batch = 0 + self.total_batches = 0 + self.start_time = time.time() + + # Ensure directory exists + if self.is_main_process: + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def load_checkpoint(self) -> Optional[Dict[str, Any]]: + """ + Load existing checkpoint if available, including hook states. + + Returns: + Dict with checkpoint info or None if no checkpoint exists + """ + aprint(f"[Rank {self.rank}] Looking for checkpoint at: {self.progress_file}") + if not self.progress_file.exists(): + aprint(f"[Rank {self.rank}] No checkpoint file found at {self.progress_file}") + return None + + try: + with open(self.progress_file, "r") as f: + checkpoint_data = json.load(f) + + # Validate checkpoint + if "current_batch" in checkpoint_data and "total_batches" in checkpoint_data: + self.current_batch = checkpoint_data["current_batch"] + self.total_batches = checkpoint_data["total_batches"] + + mprint( + f"Found checkpoint: batch {self.current_batch}/{self.total_batches} ({checkpoint_data.get('progress', 0.0):.1%})" + ) + mprint( + f"Will resume from batch {self.current_batch}, skipping batches 0-{self.current_batch - 1}" + ) + + # Load hook states if hooks are available + if self.activation_hooks is not None: + success = self.load_hook_states(self.activation_hooks) + if success: + aprint( + f"[Rank {self.rank}] Successfully loaded hook states from checkpoint" + ) + else: + aprint(f"[Rank {self.rank}] Failed to load hook states - starting fresh") + + return checkpoint_data + else: + aprint( + f"[Rank {self.rank}] Invalid checkpoint format (missing current_batch/total_batches): {checkpoint_data}" + ) + return None + + except (json.JSONDecodeError, KeyError) as e: + mprint(f"Error loading checkpoint: {e}") + + return None + + def load_hook_states(self, activation_hooks) -> bool: + """ + Load hook states from checkpoint files. 
+ + Args: + activation_hooks: Hook objects to load states into + + Returns: + bool: True if hook states were successfully loaded, False otherwise + """ + import os + + # Each rank loads only its own hook states + current_rank = int(os.environ.get("RANK", 0)) + hook_states_path = self.checkpoint_dir / f"hook_states_rank_{current_rank}.pth" + + if hook_states_path.exists(): + aprint(f"[Rank {current_rank}] Loading hook states from {hook_states_path}") + try: + import torch + + hook_states = torch.load(hook_states_path, map_location="cpu") + + # Load states into corresponding hooks + loaded_count = 0 + for module_name, hook in activation_hooks.items(): + if module_name in hook_states: + hook.load_state(hook_states[module_name]) + loaded_count += 1 + + # Log progress info if available (only for a few hooks to avoid spam) + if loaded_count <= 3: # Only log first few hooks + progress_info = hook.get_progress_info() + if progress_info: + aprint(f"[Rank {current_rank}] {module_name}: {progress_info}") + else: + aprint( + f"[Rank {current_rank}] Warning: No saved state found for hook: {module_name}" + ) + + aprint( + f"[Rank {current_rank}] Successfully loaded states for {loaded_count}/{len(activation_hooks)} hooks" + ) + return True + + except Exception as e: + aprint(f"[Rank {current_rank}] Error loading hook states: {e}") + return False + else: + aprint(f"[Rank {current_rank}] No hook states file found at {hook_states_path}") + return False + + def should_skip_batch(self, batch_idx: int) -> bool: + """Check if we should skip this batch (already processed in previous run).""" + should_skip = batch_idx < self.current_batch + if should_skip and batch_idx % 10 == 0: # Log every 10th skipped batch to avoid spam + mprint(f"Skipping batch {batch_idx} (resume from batch {self.current_batch})") + return should_skip + + def update_progress(self, batch_idx: int, total_batches: int): + """ + Update progress and potentially save checkpoint. + + Args: + batch_idx: Current batch index + total_batches: Total number of batches + """ + self.current_batch = batch_idx + self.total_batches = total_batches + + # Save checkpoint periodically or on completion + should_save = ( + (batch_idx % self.checkpoint_interval == 0) # Periodic save + or (batch_idx == total_batches - 1) # Final batch + ) + + if should_save: + # All ranks save their hook states + if self.activation_hooks is not None: + try: + from utils.activation_hooks.hooks import ActivationsHook + + saved_path = ActivationsHook.save_hook_states( + self.activation_hooks, self.checkpoint_dir, self.runtime + ) + except Exception as e: + mprint(f"Warning: Failed to save hook states: {e}") + + # Only main process saves progress info + if self.is_main_process: + self.save_checkpoint() + + # Synchronize all ranks after checkpointing + if self.runtime is not None: + self.runtime.wait_for_everyone() + + def save_checkpoint(self): + """ + Save current checkpoint to disk (progress info only). + Hook states are saved separately in update_progress. 
+        """
+        try:
+            # Save progress
+            progress_data = {
+                "current_batch": self.current_batch,
+                "total_batches": self.total_batches,
+                "progress": self.current_batch / self.total_batches
+                if self.total_batches > 0
+                else 0.0,
+                "timestamp": time.time(),
+                "elapsed_time": time.time() - self.start_time,
+                "rank": self.rank,
+            }
+
+            # Write progress atomically
+            temp_file = self.progress_file.with_suffix(".tmp")
+            with open(temp_file, "w") as f:
+                json.dump(progress_data, f, indent=2)
+            temp_file.replace(self.progress_file)
+
+            # Hook states are saved at a higher level to ensure all ranks participate
+
+            if self.current_batch % (self.checkpoint_interval) == 0:
+                progress_pct = progress_data["progress"] * 100
+                elapsed = progress_data["elapsed_time"]
+                mprint(
+                    f"Checkpoint saved: batch {self.current_batch}/{self.total_batches} ({progress_pct:.1f}%), elapsed: {elapsed:.1f}s"
+                )
+
+        except Exception as e:
+            mprint(f"Error saving checkpoint: {e}")
+
+    def finalize(self):
+        """Mark scoring as completed."""
+        # All ranks save their final hook states
+        if self.activation_hooks is not None:
+            try:
+                from utils.activation_hooks.hooks import ActivationsHook
+
+                saved_path = ActivationsHook.save_hook_states(
+                    self.activation_hooks, self.checkpoint_dir, self.runtime
+                )
+                mprint(f"Final hook states saved to {saved_path}")
+            except Exception as e:
+                mprint(f"Warning: Failed to save final hook states: {e}")
+
+        # Only main process saves progress info
+        if self.is_main_process:
+            self.current_batch = self.total_batches
+            self.save_checkpoint()
+            mprint(f"Scoring completed and finalized: {self.total_batches} batches processed")
+
+        # Synchronize all ranks after finalization
+        if self.runtime is not None:
+            self.runtime.wait_for_everyone()
diff --git a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml
index 046ff51f6..178edb50d 100644
--- a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml
+++ b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml
@@ -12,4 +12,4 @@ write_results: false
 calc_losses_on_cpu: false
 activations_log_dir:
 model_name_or_path:
-load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn}
+load_dataset_fn: ${get_object:modelopt.torch._compress.utils.data.dataloaders.load_from_disk_fn}

From d541baac647fb5f49a0268443edad164742b5adf Mon Sep 17 00:00:00 2001
From: Daniel Korzekwa
Date: Thu, 20 Nov 2025 14:38:06 +0100
Subject: [PATCH 38/49] Delete unneeded tokenizer.

Signed-off-by: Daniel Korzekwa
---
 .../deci_lm_hf_code/tokenization_mistral.py | 374 ------------------
 1 file changed, 374 deletions(-)
 delete mode 100644 modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py

diff --git a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py b/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py
deleted file mode 100644
index e67674a09..000000000
--- a/modelopt/torch/_compress/decilm/deci_lm_hf_code/tokenization_mistral.py
+++ /dev/null
@@ -1,374 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Based on https://github.com/vllm-project/vllm/blob/739e03b3449a7f3b0a81ebc30b9555305d914e2d/vllm/transformers_utils/tokenizers/mistral.py -# mypy: ignore-errors - -import os -import re -import sys -from pathlib import Path -from shutil import copyfile -from typing import TYPE_CHECKING, Any - -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging - -if TYPE_CHECKING: - from mistral_common.protocol.instruct.request import ChatCompletionRequest - -logger = logging.get_logger(__name__) - - -def _called_from_vllm() -> bool: - frame = sys._getframe(1) - while frame: - mod = frame.f_globals.get("__name__", "") - if mod == "vllm" or mod.startswith("vllm."): - return True - frame = frame.f_back - return False - - -class HFAdaptedMistralTokenizer(PreTrainedTokenizer): - """ - In order to save the tokenizer, do the following: - ``` - # from import HFAdaptedMistralTokenizer - # from mistral_common.tokens.tokenizers.base import SpecialTokens - HFAdaptedMistralTokenizer.register_for_auto_class("AutoTokenizer") - tokenizer = HFAdaptedMistralTokenizer("", chat_template="dummy") - tokenizer.add_special_tokens( - {"additional_special_tokens": [v.value for _, v in SpecialTokens.__members__.items()]} - ) - tokenizer.save_pretrained("") - ``` - """ - - vocab_files_names = {"path_indicator": "tokenizer_config.json"} - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - path_indicator: str, - unk_token: str | None = None, - bos_token: str | None = None, - eos_token: str | None = None, - pad_token: str | None = None, - add_bos_token: bool = True, - add_eos_token: bool = False, - clean_up_tokenization_spaces: bool = False, - **kwargs, - ): - path_indicator: Path = Path(path_indicator) - if path_indicator.name == "tokenizer_config.json": - path_indicator = path_indicator.parent - if path_indicator.is_dir(): - tokenizer_file_name = _find_tokenizer_file(os.listdir(path_indicator)) - tokenizer_file = str(path_indicator / tokenizer_file_name) - else: - tokenizer_file = path_indicator - self._mistral_tokenizer_path = str(tokenizer_file) - - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer as MistralTokenizer - - self._mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file) - self._instruct_tokenizer = self._mistral_tokenizer.instruct_tokenizer - - # Copied from https://github.com/patrickvonplaten/vllm/blob/6cca3d8c330e169bbf386561c441ca5f3879cf85/vllm/transformers_utils/tokenizers/mistral.py - self.version: int = int( - self._instruct_tokenizer.tokenizer.version.value.split("v")[-1].split("m")[0] - ) - - tokenizer_ = self._instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer - - self.is_tekken = isinstance(tokenizer_, Tekkenizer) - from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer - - self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - 
else: - raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") - - self._vocab = tokenizer_.vocab() - # Convert to a Dict[str, int] to match protocol, but this is a lossy - # conversion. There may be multiple token ids that decode to the same - # string due to partial UTF-8 byte sequences being converted to � - self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} - self._tokenizer = tokenizer_ - self._max_token_id = self.vocab_size - 1 - self.vocab = self._vocab_dict - - bos_token = ( - bos_token - if bos_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.bos_id], - normalized=False, - special=True, - ) - ) - eos_token = ( - eos_token - if eos_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.eos_id], - normalized=False, - special=True, - ) - ) - unk_token = ( - unk_token - if unk_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.unk_id], - normalized=False, - special=True, - ) - ) - pad_token = ( - pad_token - if pad_token - else AddedToken( - self._tokenizer._vocab[self._tokenizer.pad_id], - normalized=False, - special=True, - ) - ) - - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - - self._in_vllm = _called_from_vllm() - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - - @property - def vocab_size(self): - """Returns vocab size""" - return self._tokenizer.n_words - - def get_vocab(self): - """Returns vocab as a dict""" - return self._vocab_dict - - def tokenize( - self, - text: str, - pair: str | None = None, - add_special_tokens: bool | None = None, - **kwargs, - ) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - if add_special_tokens is None: - bos = self._add_bos_token - eos = self._add_eos_token - else: - bos = add_special_tokens - eos = add_special_tokens - - input_ids = [] - parts = self.tokens_trie.split(text) - - in_vllm_chat_completion_mode = False - if ( - self._in_vllm - and len(parts) > 1 - and parts[0] == SpecialTokens.bos.value - and parts[1] == SpecialTokens.begin_inst.value - ): - # This is a dangerous hack to make the tokenizer work with vLLM. - # It means we are in chat completion mode. - bos = False - eos = False - in_vllm_chat_completion_mode = True - - if os.environ.get("HF_TOKENIZE_FORCE_NO_SPECIAL_TOKENS", "0") == "1": - bos = False - eos = False - - if not self._in_vllm or in_vllm_chat_completion_mode: - for part in parts: - if part in self.additional_special_tokens and part in self._vocab_dict: - input_ids.append(self._convert_token_to_id(part)) - else: - input_ids.extend(self._tokenizer.encode(part, bos=bos, eos=eos)) - else: - # Doesn't tokenize special tokens properly, but this is the behavior of vLLM when we are in completion mode. 
- input_ids = self._tokenizer.encode(text, bos=bos, eos=eos) - - if os.environ.get("HF_TOKENIZE_ABUSE", "1") == "1": - # A lot faster than the other option - return input_ids - else: - return [self._convert_id_to_token(token_id) for token_id in input_ids] - - def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: - if len(tokens) > 0 and isinstance(tokens[0], int): - return tokens - return super().convert_tokens_to_ids(tokens) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self._vocab_dict[token] - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - piece = self._tokenizer.id_to_piece(index) - return piece if isinstance(piece, str) else piece.value - - def convert_tokens_to_string(self, tokens: list[str]) -> str: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - if self.is_tekken: - tokens = [ - t - for t in tokens - if (t is SpecialTokens.tool_calls or t not in self._tokenizer._all_special_tokens) - ] - - if any(isinstance(t, bytes) for t in tokens): - # we need to encode and decode all tokens again - shift = self._tokenizer.num_special_tokens - - def _token_to_id(t: str): - t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t - try: - return shift + self._tokenizer._tekken_token2id_nospecial[t_bytes] - except KeyError: - logger.warning( - "Failed to convert token %s to id, replacing with ", - t_bytes, - ) - return self._tokenizer.unk_id - - ids = [_token_to_id(t) for t in tokens] - decoded = self._tokenizer.decode(ids) - else: - decoded = "".join(tokens) - else: - # make sure certain special tokens like Tool calls are - # not decoded - special_tokens = {SpecialTokens.tool_calls} - regular_tokens: list[str] = [] - decoded_list = [] - - for token in tokens: - if token in special_tokens: - if regular_tokens: - decoded_list.append(self._tokenizer.decode(regular_tokens)) - regular_tokens = [] - decoded_list.append(token) - else: - regular_tokens.append(token) - - if regular_tokens: - decoded_list.append(self._tokenizer.decode(regular_tokens)) # type: ignore[no-untyped-call] - - decoded = "".join(decoded_list) - - return decoded - - def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]: - """ - Use this method to save the full tokenizer file. - """ - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join(save_directory, "tekken.json") - - if os.path.abspath(self._mistral_tokenizer_path) != os.path.abspath(out_vocab_file): - copyfile(self._mistral_tokenizer_path, out_vocab_file) - - return (out_vocab_file,) - - def apply_chat_template( - self, - conversation: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - tokenize: bool = True, - **kwargs, - ) -> list[int]: - request = _make_mistral_chat_completion_request(conversation, tools) - encoded = self._mistral_tokenizer.encode_chat_completion(request) - if tokenize: - # encode-decode to get clean prompt - return encoded.tokens - else: - return encoded.text - - -def _find_tokenizer_file(files: list[str]): - file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") - - matched_files = [file for file in files if file_pattern.match(file)] - if len(matched_files) > 1: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. 
Make sure only one Mistral " - f"tokenizer is present in {files}." - ) - elif len(matched_files) == 0: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}." - ) - - return matched_files[0] - - -def _make_mistral_chat_completion_request( - messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None -) -> "ChatCompletionRequest": - last_message = messages[-1] - if last_message["role"] == "assistant": - last_message["prefix"] = True - - # mistral-common requires AssistantMessage content to be string [1]. - # - # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 - for message in messages: - if message.get("role") == "assistant": - content = message.get("content") - if isinstance(content, list): - content = "\n".join(chunk.get("text") for chunk in content) - message["content"] = content - - # The Mistral client, in comparison to the OpenAI client, requires the - # "parameters" dict to be present, even if it's empty. - if tools: - for function in [tool["function"] for tool in tools if tool["type"] == "function"]: - if function.get("parameters") is None: - function["parameters"] = {} - - from mistral_common.protocol.instruct.request import ChatCompletionRequest - - return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] From ae07708578a859f2a80bc0f71ecb078b3ffd8c85 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 16:42:41 +0100 Subject: [PATCH 39/49] Fix imports Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- modelopt/torch/_compress/tools/__init__.py | 15 +++++++++++++++ .../_compress/tools/sharded_checkpoint_utils.py | 2 +- modelopt/torch/_compress/tools/validate_model.py | 11 +++++++---- 4 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 modelopt/torch/_compress/tools/__init__.py diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 7d955c5ca..2271c7922 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -26,7 +26,7 @@ import score_pruning_activations import scoring from omegaconf import DictConfig -from puzzle_tools.runtime import IRuntime +from modelopt.torch._compress.tools.runtime import IRuntime from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir diff --git a/modelopt/torch/_compress/tools/__init__.py b/modelopt/torch/_compress/tools/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/modelopt/torch/_compress/tools/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py index 5df79b79a..549ee9a88 100644 --- a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -30,7 +30,6 @@ import torch.nn as nn from huggingface_hub import split_torch_state_dict_into_shards from modelopt.torch._compress.tools.logger import mprint -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file @@ -42,6 +41,7 @@ from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, + DeciLMForCausalLM, rope_type_to_class, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index fe12a8c55..0e745a064 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -32,10 +32,13 @@ PreTrainedModel, PreTrainedTokenizerBase, ) -from utils.data.dataloaders import create_validation_dataloader -from utils.parsing import simple_parse_args_string -from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline -from utils.validation import calculate_losses +from modelopt.torch._compress.utils.data.dataloaders import create_validation_dataloader +from modelopt.torch._compress.utils.parsing import simple_parse_args_string +from modelopt.torch._compress.utils.validate_runtime_pipeline import ( + HiddenStatesAndLMHead, + calculate_losses_pipeline, +) +from modelopt.torch._compress.utils.validation import calculate_losses from modelopt.torch._compress.activation_scoring.activation_hooks.utils import ( register_activation_hooks, From a319d33799bc96aa6049e1f92a6ace4370ea4659 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 17:42:25 +0100 Subject: [PATCH 40/49] fix imports Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/compress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index 2271c7922..64e241d10 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -23,7 +23,7 @@ import build_library_and_stats import mip_and_realize_models import pruning_ckpts -import score_pruning_activations +import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations import scoring from omegaconf import DictConfig from modelopt.torch._compress.tools.runtime import IRuntime From a09e894a2755486af6285301262a0899c2dfb8aa Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 17:59:18 +0100 Subject: [PATCH 41/49] fix imports Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/utils/checkpoint_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/_compress/utils/checkpoint_manager.py b/modelopt/torch/_compress/utils/checkpoint_manager.py index 7bc3e73ef..318586ba4 100644 --- a/modelopt/torch/_compress/utils/checkpoint_manager.py +++ b/modelopt/torch/_compress/utils/checkpoint_manager.py @@ -195,7 +195,9 @@ def 
update_progress(self, batch_idx: int, total_batches: int): # All ranks save their hook states if self.activation_hooks is not None: try: - from utils.activation_hooks.hooks import ActivationsHook + from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( + ActivationsHook, + ) saved_path = ActivationsHook.save_hook_states( self.activation_hooks, self.checkpoint_dir, self.runtime @@ -252,7 +254,9 @@ def finalize(self): # All ranks save their final hook states if self.activation_hooks is not None: try: - from utils.activation_hooks.hooks import ActivationsHook + from modelopt.torch._compress.activation_scoring.activation_hooks.hooks import ( + ActivationsHook, + ) saved_path = ActivationsHook.save_hook_states( self.activation_hooks, self.checkpoint_dir, self.runtime From c69c4a95dd801edf9e5ac98dbb0219f85cac1b18 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 18:05:29 +0100 Subject: [PATCH 42/49] Improve doc strings Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/utils/parsing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modelopt/torch/_compress/utils/parsing.py b/modelopt/torch/_compress/utils/parsing.py index 6e880dd99..97f698ba9 100644 --- a/modelopt/torch/_compress/utils/parsing.py +++ b/modelopt/torch/_compress/utils/parsing.py @@ -12,6 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Parsing and formatting utilities for configuration handling in model compression. + +This module provides utilities for: +- Parsing command-line arguments and configuration strings +- Formatting and displaying model configurations (block configs, attention, FFN) +- Formatting loss metrics for logging and visualization +""" # mypy: ignore-errors import json From 1ee802670af55f6c6fdddcc92a679555c7a7cca4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 18:16:34 +0100 Subject: [PATCH 43/49] Improve doc strings Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/utils/data/dataloaders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modelopt/torch/_compress/utils/data/dataloaders.py b/modelopt/torch/_compress/utils/data/dataloaders.py index 4fc856fbb..4c4fce060 100644 --- a/modelopt/torch/_compress/utils/data/dataloaders.py +++ b/modelopt/torch/_compress/utils/data/dataloaders.py @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +DataLoader utilities for language model training and validation. +""" + import os from collections.abc import Callable, Mapping, Sequence from functools import partial From d208f2d5a4c14b8f443c9e0acaad25d8841aaeb4 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 18:33:33 +0100 Subject: [PATCH 44/49] Improve doc strings. Signed-off-by: Daniel Korzekwa --- .../torch/_compress/utils/validate_runtime_pipeline.py | 9 +++++++++ modelopt/torch/_compress/utils/validation.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py index e4c5c6bfd..08e1221a7 100644 --- a/modelopt/torch/_compress/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/_compress/utils/validate_runtime_pipeline.py @@ -12,6 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. + +Coordinates forward passes and loss computation through model shards distributed across GPUs +using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. + +Used by validate_model.py during activation scoring for sharded models. +""" # mypy: ignore-errors from statistics import mean diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py index 8434f95cf..6072c2f16 100644 --- a/modelopt/torch/_compress/utils/validation.py +++ b/modelopt/torch/_compress/utils/validation.py @@ -12,6 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Model validation and loss calculation utilities for single-GPU and multi-GPU setups. + +Also provides helper functions for loss metrics, KL divergence, JS divergence, +and similarity losses for knowledge distillation. +""" + # mypy: ignore-errors import functools import math From 7a89dfa107f65e2dbe1da359bab511ca022da34d Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 20 Nov 2025 19:11:33 +0100 Subject: [PATCH 45/49] Remove not used stuff from kd_model + add doc string Signed-off-by: Daniel Korzekwa --- modelopt/torch/_compress/tools/kd_model.py | 304 +-------------------- 1 file changed, 7 insertions(+), 297 deletions(-) diff --git a/modelopt/torch/_compress/tools/kd_model.py b/modelopt/torch/_compress/tools/kd_model.py index cbff6aab3..437eb51ca 100644 --- a/modelopt/torch/_compress/tools/kd_model.py +++ b/modelopt/torch/_compress/tools/kd_model.py @@ -12,6 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Knowledge distillation loss functions. + +Provides normalized_mse_loss and cosine_embedding_loss_batched for comparing +model outputs. Used by validation.py. +""" # mypy: ignore-errors from abc import ABCMeta, abstractmethod @@ -22,19 +29,6 @@ from torch import nn, Tensor -class Block(nn.Module): - def __init__(self, *args, **kwargs): - raise NotImplementedError("This class is deprecated. Deci models are now hf models.") - - -class DummyBlock(nn.Module): - def __init__(self, *args, **kwargs): - raise NotImplementedError("This class is deprecated. 
Deci models are now hf models.") - - -RoPECache = torch.Tensor - - def normalized_mse_loss( input: Tensor, target: Tensor, @@ -57,287 +51,3 @@ def cosine_embedding_loss_batched(input: Tensor, target: Tensor) -> Tensor: input1=input, input2=target, target=target_tensor, reduction="none" ) return loss - - -def cross_entropy_probs_batched(logits_input: Tensor, logits_target: Tensor) -> Tensor: - return F.cross_entropy( - logits_input.transpose(1, 2), logits_target.softmax(-1).transpose(1, 2), reduction="none" - ).mean(-1) - - -def kl_div_logits_batched(logits_input: Tensor, logits_target: Tensor) -> Tensor: - return ( - F.kl_div( - logits_input.log_softmax(-1), - logits_target.log_softmax(-1), - reduction="none", - log_target=True, - ) - .sum(-1) - .mean(-1) - ) - - -def kl_div_single_sample(logits_input: Tensor, logits_target: Tensor) -> Tensor: - return F.kl_div( - logits_input.log_softmax(-1), - logits_target.log_softmax(-1), - reduction="batchmean", - log_target=True, - ) - - -def kl_div_logits_batched_mem_efficient(logits_input: Tensor, logits_target: Tensor) -> Tensor: - batch_size = logits_input.shape[0] - kl_div_per_sample = [ - kl_div_single_sample(logits_input[i], logits_target[i]) for i in range(batch_size) - ] - return torch.stack(kl_div_per_sample) - - -kl_div = kl_div_logits_batched_mem_efficient - - -def mse_loss( - x_input: torch.Tensor, - x_target: torch.Tensor, -) -> Tensor: - return torch.stack( - [F.mse_loss(x_input[i_sample], x_target[i_sample]) for i_sample in range(x_input.shape[0])] - ) - - -def reverse_kl_div(logits_input: Tensor, logits_target: Tensor) -> Tensor: - return kl_div_logits_batched_mem_efficient(logits_target, logits_input) - - -def tv_dist(logits_input: Tensor, logits_target: Tensor) -> Tensor: - """ - Total Variation Distance: L1-loss between probabilities. - vocab dimension is summed, sequence dimension is averaged. - """ - batch_size, seq_len, vocab_size = logits_input.shape - tv_dist_per_sample = [ - F.l1_loss(logits_input[i].softmax(-1), logits_target[i].softmax(-1), reduction="sum") - / seq_len - for i in range(batch_size) - ] - return torch.stack(tv_dist_per_sample) - - -def js_div(logits_input: Tensor, logits_target: Tensor) -> Tensor: - """ - Jensen-Shannon Divergence for a single sample. 
- logits: [tokens, vocab] - target_probs: [tokens, vocab] - """ - batch_size = logits_input.shape[0] - _js_div = [] - for i in range(batch_size): - input_probs = logits_input[i].softmax(-1) - target_probs = logits_target[i].softmax(-1) - mixture_probs = (input_probs + target_probs) * 0.5 - mixture_logprobs = mixture_probs.log().clip(min=-20) - pred_kl_div = kl_div_single_sample(mixture_logprobs, input_probs) - target_kl_div = kl_div_single_sample(mixture_logprobs, target_probs) - js_div_i = 0.5 * (pred_kl_div + target_kl_div) - _js_div.append(js_div_i) - return torch.stack(_js_div) - - -LOGITS_LOSS_NAME_TO_FUNC = {name: func for name, func in globals().items() if callable(func)} - - -class KDLossWeigher(metaclass=ABCMeta): - @abstractmethod - def __call__( - self, - lm_loss: Tensor, - kd_block_loss: Tensor, - kd_logits_loss: Tensor, - ) -> Tensor: - raise NotImplementedError() - - -class StaticKDLossWeigher(KDLossWeigher): - def __init__( - self, - lm_weight: float, - kd_block_weight: float, - kd_logits_weight: float, - ): - self.lm_weight = lm_weight - self.kd_block_weight = kd_block_weight - self.kd_logits_weight = kd_logits_weight - - def __call__( - self, - lm_loss: Tensor, - kd_block_loss: Tensor, - kd_logits_loss: Tensor, - ) -> Tuple: - lm_loss = self.lm_weight * lm_loss - kd_block_loss = self.kd_block_weight * kd_block_loss - kd_logits_loss = self.kd_logits_weight * kd_logits_loss - - loss = lm_loss + kd_block_loss + kd_logits_loss - return loss, lm_loss, kd_block_loss, kd_logits_loss - - -class KDModel(nn.Module): - def __init__( - self, - student_model, - teacher_model, - block_loss_func: Callable, - logits_loss_func: Callable, - kd_loss_weigher: StaticKDLossWeigher, - teacher_requires_rope: bool = False, - ): - super().__init__() - assert not student_model.abs_positional - student_uses_rope = student_model.config.position_embedding_type in ["rope", "rope_llama4"] - teacher_uses_rope = teacher_model.config.position_embedding_type in ["rope", "rope_llama4"] - assert (student_uses_rope and teacher_uses_rope) or ( - not student_uses_rope and not teacher_uses_rope - ), "We do not support mixed rope usage" - self.use_rope = student_uses_rope - - self.logits_loss_func = logits_loss_func - self.block_loss_func = block_loss_func - # teacher_model.eval() - # teacher_model.requires_grad_(False) - self.student_wte, self.teacher_wte = ( - student_model.transformer.wte, - teacher_model.transformer.wte, - ) - self.blocks = nn.ModuleList() - student_blocks = student_model.transformer.h - teacher_blocks = teacher_model.transformer.h - for i in range(max(len(teacher_blocks), len(student_blocks))): - student_block = student_blocks[i] if i < len(student_blocks) else DummyBlock() - teacher_block = teacher_blocks[i] if i < len(teacher_blocks) else DummyBlock() - combined_block = KDBlock( - student_block, teacher_block, teacher_requires_rope=teacher_requires_rope - ) - self.blocks.append(combined_block) - self.student_ln_f, self.teacher_ln_f = ( - student_model.transformer.ln_f, - teacher_model.transformer.ln_f, - ) - self.student_lm_head, self.teacher_lm_head = student_model.lm_head, teacher_model.lm_head - self.block_size = student_model.block_size - - for teacher_module in (self.teacher_wte, self.teacher_ln_f, self.teacher_lm_head): - teacher_module.eval() - teacher_module.requires_grad_(False) - self.kd_loss_weigher = kd_loss_weigher - - def forward( - self, - idx: Tensor, - rope_cache: Optional[RoPECache] = None, - max_seq_length: Optional[int] = None, - varlen: bool = False, - concat_token_id: 
Optional[int] = None, - input_pos: Optional[int] = None, - kv_caches: Optional[List[torch.Tensor]] = None, - is_decode: bool = False, - ) -> tuple[Tensor, Tensor, Tensor]: - B, T = idx.size() - - block_size = self.block_size - if max_seq_length is None: - max_seq_length = block_size - if varlen: - assert B == 1, "Varlen can be used only with batch_size==1" - cu_seqlens = self.prepare_cu_seqlens(idx[0], block_size, concat_token_id) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - assert max_seqlen <= max_seq_length, ( - f"Cannot forward sequence of length {max_seqlen}, max seq length is only {max_seq_length}" - ) - assert max_seqlen <= block_size, ( - f"Cannot forward sequence of length {max_seqlen}, block size is only {block_size}" - ) - else: - cu_seqlens = None - max_seqlen = None - assert T <= max_seq_length, ( - f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" - ) - assert T <= block_size, ( - f"Cannot forward sequence of length {T}, block size is only {block_size}" - ) - assert max_seq_length <= block_size, ( - f"Cannot attend to {max_seq_length}, block size is only {block_size}" - ) - - # forward the model itself - rope = None - if self.use_rope: - if rope_cache is None: - if self.rope_cache is None: - self.rope_cache = self.build_rope_cache( - dtype=torch.float32, device=torch.device("cpu") - ) - if idx.device != torch.device("meta"): - self.rope_cache = self.rope_cache.to(idx.device) - rope_cache = self.rope_cache - if input_pos is None: - rope = rope_cache[:T] - else: - rope = rope_cache[input_pos - T : input_pos] - - x_student, x_teacher = self.student_wte(idx), self.teacher_wte(idx) - - kd_block_loss = torch.zeros(B, device=idx.device, dtype=torch.float32) - for i_block, block in enumerate(self.blocks): - x_student, x_teacher = block(x_student, x_teacher, rope) - if self.kd_loss_weigher.kd_block_weight > 0: - curr_block_loss = self.block_loss_func(x_student, x_teacher) - kd_block_loss = kd_block_loss + curr_block_loss.float() / len(self.blocks) - - x_student, x_teacher = self.student_ln_f(x_student), self.teacher_ln_f(x_teacher) - logits_student, logits_teacher = ( - self.student_lm_head(x_student), - self.teacher_lm_head(x_teacher), - ) # (b, t, vocab_size) - if self.kd_loss_weigher.kd_logits_weight > 0: - kd_logits_loss = self.logits_loss_func(logits_student, logits_teacher) - else: - kd_logits_loss = torch.zeros(B, device=idx.device, dtype=torch.float32) - - return logits_student, logits_teacher, kd_block_loss, kd_logits_loss - - def train(self, mode: bool = True): - self.student_wte.train(mode) - self.student_ln_f.train(mode) - self.student_lm_head.train(mode) - for block in self.blocks: - block.student_block.train(mode) - return self - - -class KDBlock(nn.Module): - def __init__(self, student_block, teacher_block, teacher_requires_rope): - super().__init__() - self.student_block = student_block - self.teacher_block = teacher_block - self.teacher_requires_rope = teacher_requires_rope - - self.teacher_block.eval() - self.teacher_block.requires_grad_(False) - - def forward( - self, x_student: Tensor, x_teacher: Tensor, rope: Optional[RoPECache] - ) -> tuple[Tensor, Tensor]: - x_student = self.forward_block(x_student, self.student_block, rope=rope) - x_teacher = self.forward_block( - x_teacher, self.teacher_block, rope=rope if self.teacher_requires_rope else None - ) - - return x_student, x_teacher - - def forward_block(self, x: Tensor, block: Block, rope: Optional[RoPECache] = None) -> Tensor: - x = block(x=x, rope=rope, 
-        return x

From 09ac4206512f933abae627b1adde466668fa1011 Mon Sep 17 00:00:00 2001
From: Daniel Korzekwa
Date: Fri, 21 Nov 2025 16:38:40 +0100
Subject: [PATCH 46/49] Remove empty module

Signed-off-by: Daniel Korzekwa
---
 .../sewing_kit/passage/recipes/__init__.py | 15 ---------------
 1 file changed, 15 deletions(-)
 delete mode 100644 modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py

diff --git a/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py b/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py
deleted file mode 100644
index 47f1c65a1..000000000
--- a/modelopt/torch/_compress/sewing_kit/passage/recipes/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-

From 22a3afb4d6fbe0281ca9b9c2aed40381eb489277 Mon Sep 17 00:00:00 2001
From: Daniel Korzekwa
Date: Mon, 24 Nov 2025 21:08:47 +0100
Subject: [PATCH 47/49] Update modelopt/torch/_compress/utils/data/dataset.py

Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Daniel Korzekwa
---
 modelopt/torch/_compress/utils/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/_compress/utils/data/dataset.py b/modelopt/torch/_compress/utils/data/dataset.py
index 7398b378c..2c7fcef09 100644
--- a/modelopt/torch/_compress/utils/data/dataset.py
+++ b/modelopt/torch/_compress/utils/data/dataset.py
@@ -223,7 +223,7 @@ def prepare_cu_seqlens(self, input_ids):
         return cu_seqlens
 
 
-## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
+## Adapted from https://github.com/NVIDIA/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
 def permute(
     sample,
     np_rng,

From 138d01e0527900a505806e25202655e021102411 Mon Sep 17 00:00:00 2001
From: Daniel Korzekwa
Date: Mon, 24 Nov 2025 21:10:42 +0100
Subject: [PATCH 48/49] remove if-else check for Python < 3.9

Signed-off-by: Daniel Korzekwa
---
 modelopt/torch/_compress/sewing_kit/passage/core.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/modelopt/torch/_compress/sewing_kit/passage/core.py b/modelopt/torch/_compress/sewing_kit/passage/core.py
index 9c200d569..4a66638aa 100644
--- a/modelopt/torch/_compress/sewing_kit/passage/core.py
+++ b/modelopt/torch/_compress/sewing_kit/passage/core.py
@@ -17,10 +17,7 @@
 from __future__ import annotations
 
 import sys
 
-if sys.version_info[1] < 9:  # if less than pytorch 3.9
-    from typing import Sequence, Callable
-else:
-    from collections.abc import Sequence, Callable
+from collections.abc import Sequence, Callable
 from dataclasses import dataclass
 from typing import Any, ContextManager, Iterable, Mapping, Optional, Union

From 5a84ec9b10b2c60a2e391456e1d650ca69521d1e Mon Sep 17 00:00:00 2001
From: Daniel Korzekwa
Date: Mon, 24 Nov 2025 21:27:19 +0100
Subject: [PATCH 49/49] code clean up

Signed-off-by: Daniel Korzekwa
---
 modelopt/torch/_compress/utils/validation.py | 9 +--------
 tests/gpu/torch/export/test_fsdp2_export.py | 1 -
 2 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/modelopt/torch/_compress/utils/validation.py b/modelopt/torch/_compress/utils/validation.py
index 6072c2f16..63c664224 100644
--- a/modelopt/torch/_compress/utils/validation.py
+++ b/modelopt/torch/_compress/utils/validation.py
@@ -404,14 +404,7 @@ def _DEBUG_calculate_per_token_entropy(batch_outputs, logits, i_batch):
     # calculate the per token entropy and per token top p
     entropy = calc_entropy(logits).cpu()  # .view(-1)#.tolist()
     msftm = confidence_max_softmax(logits).cpu()  # .view(-1)#.tolist()
-    teacher_dir = (
-        "/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/"
-        "meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/"
-    )
-    # teacher_dir = (
-    #     '/lustre/fsw/portfolios/coreai/projects/coreai_nvfm_llm/models/'
-    #     'meta-llama/Meta-Llama-3.1-405B-Instruct/'
-    # )
+    teacher_dir = ".../meta-llama/Meta-Llama-3.1-70B-Instruct-new_rope/"
     file_path = f"{teacher_dir}/validation/per_token_stats_{i_batch}.pth"
     os.makedirs(os.path.dirname(file_path), exist_ok=True)
     torch.save({"entropy": entropy, "max_softmax": msftm}, file_path)
diff --git a/tests/gpu/torch/export/test_fsdp2_export.py b/tests/gpu/torch/export/test_fsdp2_export.py
index 0c3496dec..690b363fa 100644
--- a/tests/gpu/torch/export/test_fsdp2_export.py
+++ b/tests/gpu/torch/export/test_fsdp2_export.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
 
 import copy
 from functools import partial