
Commit 1ce483c

[TRTLLM-7967][feat] Adding Starcoder2 PyTorch Backend Support (#8923)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent 336593c commit 1ce483c

File tree

8 files changed, +662 -0 lines changed

tensorrt_llm/_torch/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@
 from .modeling_qwen_moe import Qwen2MoeForCausalLM
 from .modeling_seedoss import SeedOssForCausalLM
 from .modeling_siglip import SiglipVisionModel
+from .modeling_starcoder2 import Starcoder2ForCausalLM
 from .modeling_utils import get_model_architecture
 from .modeling_vila import VilaModel

@@ -62,6 +63,7 @@
     "Qwen2ForRewardModel",
     "Qwen2MoeForCausalLM",
     "SiglipVisionModel",
+    "Starcoder2ForCausalLM",
     "get_model_architecture",
     "VilaModel",
     "Qwen2VLModel",
tensorrt_llm/_torch/models/modeling_starcoder2.py (new file)

Lines changed: 287 additions & 0 deletions

@@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import nn
from transformers import Starcoder2Config

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (
    DecoderModel,
    DecoderModelForCausalLM,
    _load_weights_impl,
    register_auto_model,
)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.mlp import MLP
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType


class Starcoder2Attention(Attention):
    """
    StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: Optional[int] = None,
    ):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=config.use_bias,
            pos_embd_params=PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,
                rope=RopeParams.from_config(config),
            ),
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )

        # Configure sliding window attention (4096 tokens)
        self.attention_window_size = getattr(config, "sliding_window", 4096)

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """
        Overrides parent to pass attention_window_size parameter.
        """
        return super().forward(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_window_size=self.attention_window_size,
            **kwargs,
        )


class Starcoder2DecoderLayer(DecoderLayer):
    """
    StarCoder2 Decoder Layer.

    Architecture:
    - Layer normalization before attention (with bias)
    - Self-attention with GQA and sliding window
    - Layer normalization before MLP (with bias)
    - MLP with GELU activation
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: int,
    ):
        super().__init__()
        config = model_config.pretrained_config
        self.layer_idx = layer_idx

        self.self_attn = Starcoder2Attention(
            model_config,
            layer_idx=layer_idx,
        )

        if config.mlp_type == "default":
            self.mlp = MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                bias=config.use_bias,
                activation=nn.GELU(),
                dtype=config.torch_dtype,
                config=model_config,
            )
        else:
            raise ValueError(
                f"Unsupported mlp_type: {config.mlp_type}. Only default (linear) MLP is supported."
            )

        norm_eps = getattr(config, "norm_epsilon", 1e-5)
        self.input_layernorm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

        self.post_attention_layernorm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            **kwargs,
        )

        # Fully Connected (MLP)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)

        return hidden_states, residual


class Starcoder2Model(DecoderModel):
    """
    StarCoder2 Transformer Model.
    """

    def __init__(self, model_config: ModelConfig[Starcoder2Config]):
        super().__init__(model_config)
        config = self.model_config.pretrained_config

        self.embed_tokens = Embedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            mapping=model_config.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True,
        )

        self.layers = nn.ModuleList(
            [
                Starcoder2DecoderLayer(
                    model_config,
                    layer_idx,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        # Use norm_epsilon (Starcoder2Config attribute name)
        norm_eps = getattr(config, "norm_epsilon", 1e-5)
        self.norm = LayerNorm(
            hidden_size=config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
            has_bias=True,  # StarCoder2 uses bias in layer norm
        )

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        lora_params=None,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                lora_params=lora_params,
            )

        # Use LayerNorm's built-in residual connection support
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_auto_model("Starcoder2ForCausalLM")
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):
    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
    ):
        # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
        # For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
        torch_dtype_to_check = model_config.pretrained_config.torch_dtype
        if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32:
            model_config.pretrained_config.torch_dtype = torch.bfloat16

        super().__init__(
            Starcoder2Model(model_config),
            config=model_config,
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size,
        )

    def load_weights(self, weights, weight_mapper=None, skip_modules=None):
        """
        Load weights with custom mapping for StarCoder2.

        StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
        while our MLP module expects (up_proj, down_proj).
        """
        if skip_modules is None:
            skip_modules = []

        # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
        params_map = {
            r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
            r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
        }
        preload_weight_modules = getattr(self, "preload_weight_modules", None)
        _load_weights_impl(
            self,
            weights,
            skip_modules,
            params_map=params_map,
            preload_weight_modules=preload_weight_modules,
        )
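The params_map regexes in load_weights are the only naming translation this checkpoint appears to need: StarCoder2 keeps GPT-2 style MLP names (c_fc, c_proj) while the backend's MLP module uses up_proj/down_proj. A standalone sketch of the same substitution, with hypothetical checkpoint keys chosen purely for illustration:

    import re

    # The two rewrite rules from load_weights above.
    params_map = {
        r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
        r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
    }

    # Hypothetical HuggingFace-style keys, just to show the mapping.
    hf_keys = [
        "model.layers.0.mlp.c_fc.weight",
        "model.layers.0.mlp.c_proj.bias",
        "model.layers.0.self_attn.q_proj.weight",  # no rule matches; left unchanged
    ]

    for key in hf_keys:
        new_key = key
        for pattern, repl in params_map.items():
            new_key = re.sub(pattern, repl, new_key)
        print(key, "->", new_key)
    # model.layers.0.mlp.c_fc.weight -> model.layers.0.mlp.up_proj.weight
    # model.layers.0.mlp.c_proj.bias -> model.layers.0.mlp.down_proj.bias
    # model.layers.0.self_attn.q_proj.weight -> model.layers.0.self_attn.q_proj.weight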

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 6 additions & 0 deletions
@@ -270,3 +270,9 @@ zai-org/GLM-4.6:
   - quant_algo: NVFP4
     spec_dec_algo: MTP
     accuracy: 88.0
+bigcode/starcoder2-3b:
+  - accuracy: 20.2
+bigcode/starcoder2-7b:
+  - accuracy: 26.5
+bigcode/starcoder2-15b:
+  - accuracy: 54.5

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 46 additions & 0 deletions
@@ -4264,3 +4264,49 @@ def test_nvfp4_4gpus(self):
         if temp_dir and os.path.exists(temp_dir):
             import shutil
             shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-3b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"
+
+    @skip_pre_hopper
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
+class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-7b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"
+
+    @skip_pre_hopper
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
+class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "bigcode/starcoder2-15b"
+    MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"
+
+    @skip_pre_hopper
+    @pytest.mark.skip_less_device_memory(80000)
+    def test_auto_dtype(self):
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 cuda_graph_config=None,
+                 max_batch_size=128,
+                 max_seq_len=4096) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
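The accuracy tests above also serve as a usage reference: each StarCoder2 size is driven through the standard LLM API with the TRTLLM attention backend and evaluated on GSM8K. A minimal generation sketch along the same lines; the checkpoint path and prompt are placeholders, and SamplingParams/generate come from the general tensorrt_llm LLM API rather than anything introduced by this commit:

    from tensorrt_llm import LLM, SamplingParams

    # Placeholder path to a local bigcode/starcoder2-3b checkpoint.
    with LLM("/path/to/starcoder2-3b",
             attn_backend="TRTLLM",
             cuda_graph_config=None,
             max_batch_size=128,
             max_seq_len=4096) as llm:
        outputs = llm.generate(["def fibonacci(n):"],
                               SamplingParams(max_tokens=64))
        print(outputs[0].outputs[0].text)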

tests/integration/test_lists/qa/llm_function_nim.txt

Lines changed: 4 additions & 0 deletions
@@ -381,6 +381,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
 accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
 accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
+
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype

tests/integration/test_lists/test-db/l0_a30.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ l0_a30:
 - unittest/_torch/modeling -k "modeling_qwen"
 - unittest/_torch/modeling -k "modeling_qwen_moe"
 - unittest/_torch/modeling -k "modeling_out_of_tree"
+- unittest/_torch/modeling -k "modeling_starcoder2"
 - unittest/_torch/auto_deploy/unit/singlegpu
 - unittest/_torch/sampler/test_beam_search.py
 - unittest/_torch/sampler/test_return_logits.py

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 3 additions & 0 deletions
@@ -265,6 +265,9 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
+- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
 - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
 - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
 - test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
