|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +from typing import Optional |
| 17 | + |
| 18 | +import torch |
| 19 | +from torch import nn |
| 20 | +from transformers import Starcoder2Config |
| 21 | + |
| 22 | +from tensorrt_llm._torch.attention_backend import AttentionMetadata |
| 23 | +from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams |
| 24 | +from tensorrt_llm._torch.model_config import ModelConfig |
| 25 | +from tensorrt_llm._torch.models.modeling_utils import ( |
| 26 | + DecoderModel, |
| 27 | + DecoderModelForCausalLM, |
| 28 | + _load_weights_impl, |
| 29 | + register_auto_model, |
| 30 | +) |
| 31 | +from tensorrt_llm._torch.modules.attention import Attention |
| 32 | +from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer |
| 33 | +from tensorrt_llm._torch.modules.embedding import Embedding |
| 34 | +from tensorrt_llm._torch.modules.layer_norm import LayerNorm |
| 35 | +from tensorrt_llm._torch.modules.linear import TensorParallelMode |
| 36 | +from tensorrt_llm._torch.modules.mlp import MLP |
| 37 | +from tensorrt_llm._torch.speculative import SpecMetadata |
| 38 | +from tensorrt_llm.functional import PositionEmbeddingType |
| 39 | + |
| 40 | + |
| 41 | +class Starcoder2Attention(Attention): |
| 42 | + """ |
| 43 | + StarCoder2 Attention with Grouped Query Attention and Sliding Window support. |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + model_config: ModelConfig[Starcoder2Config], |
| 49 | + layer_idx: Optional[int] = None, |
| 50 | + ): |
| 51 | + config = model_config.pretrained_config |
| 52 | + super().__init__( |
| 53 | + hidden_size=config.hidden_size, |
| 54 | + num_attention_heads=config.num_attention_heads, |
| 55 | + num_key_value_heads=config.num_key_value_heads, |
| 56 | + max_position_embeddings=config.max_position_embeddings, |
| 57 | + bias=config.use_bias, |
| 58 | + pos_embd_params=PositionalEmbeddingParams( |
| 59 | + type=PositionEmbeddingType.rope_gpt_neox, |
| 60 | + rope=RopeParams.from_config(config), |
| 61 | + ), |
| 62 | + layer_idx=layer_idx, |
| 63 | + dtype=config.torch_dtype, |
| 64 | + config=model_config, |
| 65 | + ) |
| 66 | + |
| 67 | + # Configure sliding window attention (4096 tokens) |
| 68 | + self.attention_window_size = getattr(config, "sliding_window", 4096) |
| 69 | + |
| 70 | + def forward( |
| 71 | + self, |
| 72 | + position_ids: torch.IntTensor, |
| 73 | + hidden_states: torch.Tensor, |
| 74 | + attn_metadata: AttentionMetadata, |
| 75 | + **kwargs, |
| 76 | + ) -> torch.Tensor: |
| 77 | + """ |
| 78 | + Overrides parent to pass attention_window_size parameter. |
| 79 | + """ |
| 80 | + return super().forward( |
| 81 | + position_ids=position_ids, |
| 82 | + hidden_states=hidden_states, |
| 83 | + attn_metadata=attn_metadata, |
| 84 | + attention_window_size=self.attention_window_size, |
| 85 | + **kwargs, |
| 86 | + ) |
| 87 | + |
| 88 | + |
| 89 | +class Starcoder2DecoderLayer(DecoderLayer): |
| 90 | + """ |
| 91 | + StarCoder2 Decoder Layer. |
| 92 | +
|
| 93 | + Architecture: |
| 94 | + - Layer normalization before attention (with bias) |
| 95 | + - Self-attention with GQA and sliding window |
| 96 | + - Layer normalization before MLP (with bias) |
| 97 | + - MLP with GELU activation |
| 98 | + """ |
| 99 | + |
| 100 | + def __init__( |
| 101 | + self, |
| 102 | + model_config: ModelConfig[Starcoder2Config], |
| 103 | + layer_idx: int, |
| 104 | + ): |
| 105 | + super().__init__() |
| 106 | + config = model_config.pretrained_config |
| 107 | + self.layer_idx = layer_idx |
| 108 | + |
| 109 | + self.self_attn = Starcoder2Attention( |
| 110 | + model_config, |
| 111 | + layer_idx=layer_idx, |
| 112 | + ) |
| 113 | + |
| 114 | + if config.mlp_type == "default": |
| 115 | + self.mlp = MLP( |
| 116 | + hidden_size=config.hidden_size, |
| 117 | + intermediate_size=config.intermediate_size, |
| 118 | + bias=config.use_bias, |
| 119 | + activation=nn.GELU(), |
| 120 | + dtype=config.torch_dtype, |
| 121 | + config=model_config, |
| 122 | + ) |
| 123 | + else: |
| 124 | + raise ValueError( |
| 125 | + f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported." |
| 126 | + ) |
| 127 | + |
| 128 | + norm_eps = getattr(config, "norm_epsilon", 1e-5) |
| 129 | + self.input_layernorm = LayerNorm( |
| 130 | + hidden_size=config.hidden_size, |
| 131 | + eps=norm_eps, |
| 132 | + dtype=config.torch_dtype, |
| 133 | + has_bias=True, # StarCoder2 uses bias in layer norm |
| 134 | + ) |
| 135 | + |
| 136 | + self.post_attention_layernorm = LayerNorm( |
| 137 | + hidden_size=config.hidden_size, |
| 138 | + eps=norm_eps, |
| 139 | + dtype=config.torch_dtype, |
| 140 | + has_bias=True, # StarCoder2 uses bias in layer norm |
| 141 | + ) |
| 142 | + |
| 143 | + def forward( |
| 144 | + self, |
| 145 | + position_ids: torch.IntTensor, |
| 146 | + hidden_states: torch.Tensor, |
| 147 | + attn_metadata: AttentionMetadata, |
| 148 | + residual: Optional[torch.Tensor] = None, |
| 149 | + spec_metadata: Optional[SpecMetadata] = None, |
| 150 | + **kwargs, |
| 151 | + ): |
| 152 | + if residual is None: |
| 153 | + residual = hidden_states |
| 154 | + hidden_states = self.input_layernorm(hidden_states) |
| 155 | + else: |
| 156 | + hidden_states, residual = self.input_layernorm(hidden_states, residual) |
| 157 | + |
| 158 | + # Self Attention |
| 159 | + hidden_states = self.self_attn( |
| 160 | + position_ids=position_ids, |
| 161 | + hidden_states=hidden_states, |
| 162 | + attn_metadata=attn_metadata, |
| 163 | + **kwargs, |
| 164 | + ) |
| 165 | + |
| 166 | + # Fully Connected (MLP) |
| 167 | + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) |
| 168 | + hidden_states = self.mlp(hidden_states) |
| 169 | + |
| 170 | + if spec_metadata is not None: |
| 171 | + spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual) |
| 172 | + |
| 173 | + return hidden_states, residual |
| 174 | + |
| 175 | + |
| 176 | +class Starcoder2Model(DecoderModel): |
| 177 | + """ |
| 178 | + StarCoder2 Transformer Model. |
| 179 | + """ |
| 180 | + |
| 181 | + def __init__(self, model_config: ModelConfig[Starcoder2Config]): |
| 182 | + super().__init__(model_config) |
| 183 | + config = self.model_config.pretrained_config |
| 184 | + |
| 185 | + self.embed_tokens = Embedding( |
| 186 | + config.vocab_size, |
| 187 | + config.hidden_size, |
| 188 | + dtype=config.torch_dtype, |
| 189 | + mapping=model_config.mapping, |
| 190 | + tensor_parallel_mode=TensorParallelMode.COLUMN, |
| 191 | + gather_output=True, |
| 192 | + ) |
| 193 | + |
| 194 | + self.layers = nn.ModuleList( |
| 195 | + [ |
| 196 | + Starcoder2DecoderLayer( |
| 197 | + model_config, |
| 198 | + layer_idx, |
| 199 | + ) |
| 200 | + for layer_idx in range(config.num_hidden_layers) |
| 201 | + ] |
| 202 | + ) |
| 203 | + |
| 204 | + # Use norm_epsilon (Starcoder2Config attribute name) |
| 205 | + norm_eps = getattr(config, "norm_epsilon", 1e-5) |
| 206 | + self.norm = LayerNorm( |
| 207 | + hidden_size=config.hidden_size, |
| 208 | + eps=norm_eps, |
| 209 | + dtype=config.torch_dtype, |
| 210 | + has_bias=True, # StarCoder2 uses bias in layer norm |
| 211 | + ) |
| 212 | + |
| 213 | + def forward( |
| 214 | + self, |
| 215 | + attn_metadata: AttentionMetadata, |
| 216 | + input_ids: Optional[torch.IntTensor] = None, |
| 217 | + position_ids: Optional[torch.IntTensor] = None, |
| 218 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 219 | + spec_metadata: Optional[SpecMetadata] = None, |
| 220 | + lora_params=None, |
| 221 | + ) -> torch.Tensor: |
| 222 | + if (input_ids is None) ^ (inputs_embeds is not None): |
| 223 | + raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") |
| 224 | + |
| 225 | + if inputs_embeds is None: |
| 226 | + inputs_embeds = self.embed_tokens(input_ids) |
| 227 | + |
| 228 | + hidden_states = inputs_embeds |
| 229 | + |
| 230 | + residual = None |
| 231 | + for decoder_layer in self.layers: |
| 232 | + hidden_states, residual = decoder_layer( |
| 233 | + position_ids=position_ids, |
| 234 | + hidden_states=hidden_states, |
| 235 | + attn_metadata=attn_metadata, |
| 236 | + residual=residual, |
| 237 | + spec_metadata=spec_metadata, |
| 238 | + lora_params=lora_params, |
| 239 | + ) |
| 240 | + |
| 241 | + # Use LayerNorm's built-in residual connection support |
| 242 | + hidden_states, _ = self.norm(hidden_states, residual) |
| 243 | + return hidden_states |
| 244 | + |
| 245 | + |
| 246 | +@register_auto_model("Starcoder2ForCausalLM") |
| 247 | +class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]): |
| 248 | + def __init__( |
| 249 | + self, |
| 250 | + model_config: ModelConfig[Starcoder2Config], |
| 251 | + ): |
| 252 | + # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16). |
| 253 | + # For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency. |
| 254 | + torch_dtype_to_check = model_config.pretrained_config.torch_dtype |
| 255 | + if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32: |
| 256 | + model_config.pretrained_config.torch_dtype = torch.bfloat16 |
| 257 | + |
| 258 | + super().__init__( |
| 259 | + Starcoder2Model(model_config), |
| 260 | + config=model_config, |
| 261 | + hidden_size=model_config.pretrained_config.hidden_size, |
| 262 | + vocab_size=model_config.pretrained_config.vocab_size, |
| 263 | + ) |
| 264 | + |
| 265 | + def load_weights(self, weights, weight_mapper=None, skip_modules=None): |
| 266 | + """ |
| 267 | + Load weights with custom mapping for StarCoder2. |
| 268 | +
|
| 269 | + StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj) |
| 270 | + while our MLP module expects (up_proj, down_proj). |
| 271 | + """ |
| 272 | + if skip_modules is None: |
| 273 | + skip_modules = [] |
| 274 | + |
| 275 | + # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names |
| 276 | + params_map = { |
| 277 | + r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2", |
| 278 | + r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2", |
| 279 | + } |
| 280 | + preload_weight_modules = getattr(self, "preload_weight_modules", None) |
| 281 | + _load_weights_impl( |
| 282 | + self, |
| 283 | + weights, |
| 284 | + skip_modules, |
| 285 | + params_map=params_map, |
| 286 | + preload_weight_modules=preload_weight_modules, |
| 287 | + ) |
0 commit comments