2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -60,7 +60,7 @@ def setup(

self._flm_wrapper = FastLLMLmEvalWrapper(
model=self._hf_model,
tokenizer=self._config.tokenizer.get_tokenizer(),
tokenizer=self._config.tokenizer.get_tokenizer().tokenizer,
truncation=self._config.truncation,
logits_cache=self._config.logits_cache,
add_bos_token=self._config.add_bos_token,
4 changes: 4 additions & 0 deletions fast_llm/engine/inference/huggingface.py
@@ -244,3 +244,7 @@ def stop_workers(self):
def inner_forward(*args, **kwargs) -> tuple | transformers.utils.generic.ModelOutput:
# Meant to be overridden in derived classes
raise NotImplementedError()

@classmethod
def can_generate(cls) -> bool:
return True
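This override matters because, in recent transformers releases, GenerationMixin.generate() refuses to run (it raises a TypeError) on a model whose can_generate() returns False. A minimal sketch of that gate, assuming HuggingfaceGPTModelForCausalLM (the class exercised by the generation tests below) inherits from the wrapper class modified here:

```python
# Sketch only: checks the precondition transformers enforces before generation.
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

# generate() is gated on can_generate(); returning True above is what unlocks
# model.generate(...) on the Fast-LLM Hugging Face wrapper.
assert HuggingfaceGPTModelForCausalLM.can_generate()
```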
69 changes: 69 additions & 0 deletions fast_llm/models/gpt/conversion/qwen2.py
@@ -1,22 +1,44 @@
import typing

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
KeyValueWeightConverter,
LlamaAttentionConverter,
LlamaBaseModelConverter,
LlamaBlockConverter,
LlamaDecoderConverter,
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
LlamaMLPConverter,
QueryWeightConverter,
get_weight_and_bias_converters,
)
from fast_llm.utils import Assert


class Qwen2AttentionConverter(LlamaAttentionConverter):
# TODO: Support sliding window with max_window_layers (need 2 kinds of block?)

@classmethod
def import_config(cls, config: dict) -> dict:
config["attention_bias"] = True
out = super().import_config(config)
out["query_layer"] = {"bias": {"enabled": True}}
out["key_layer"] = {"bias": {"enabled": True}}
out["value_layer"] = {"bias": {"enabled": True}}
out["dense_layer"] = {"bias": {"enabled": False}}
return out

@classmethod
def export_config(cls, config: AttentionConfig) -> dict:
out = super().export_config(config)
del out["attention_bias"]
return out

@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(type(config), AttentionConfig)
@@ -32,9 +54,56 @@ def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(config.value_layer.bias.enabled, True)
Assert.incl(config.dense_layer.bias.enabled, (None, False))

@classmethod
def get_converters(
cls,
config: AttentionConfig,
fast_llm_prefix: str,
hf_prefix: str,
drop_on_export: bool = False,
) -> list[WeightConverter]:
return [
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.query",
f"{hf_prefix}.q_proj",
True,
QueryWeightConverter,
config,
drop_on_export=drop_on_export,
),
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.key_value",
(f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
True,
KeyValueWeightConverter,
config,
drop_on_export=drop_on_export,
),
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.dense",
f"{hf_prefix}.o_proj",
False,
drop_on_export=drop_on_export,
),
]


class Qwen2MLPConverter(LlamaMLPConverter):
@classmethod
def import_config(cls, config: dict) -> dict:
config["mlp_bias"] = False
return super().import_config(config)

@classmethod
def export_config(cls, config: MLPConfig) -> dict:
out = super().export_config(config)
del out["mlp_bias"]
return out


class Qwen2BlockConverter(LlamaBlockConverter):
mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter
mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter


class Qwen2DecoderConverter(LlamaDecoderConverter):
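For context on the bias handling in Qwen2AttentionConverter above: Qwen2 checkpoints carry biases on the q/k/v projections but not on the output projection, which is why import_config pins attention_bias before delegating to the Llama converter and then overrides the per-projection flags. A small illustration of that override fragment, with key names taken directly from the converter code (not a complete Fast-LLM config):

```python
# Illustration only: the per-projection bias overrides applied on import for Qwen2.
qwen2_attention_bias_overrides = {
    "query_layer": {"bias": {"enabled": True}},   # Qwen2 has a bias on q_proj
    "key_layer": {"bias": {"enabled": True}},     # ... and on k_proj
    "value_layer": {"bias": {"enabled": True}},   # ... and on v_proj
    "dense_layer": {"bias": {"enabled": False}},  # but none on o_proj
}
```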
66 changes: 40 additions & 26 deletions fast_llm/models/gpt/huggingface.py
@@ -4,7 +4,9 @@
import typing

import torch
import transformers.cache_utils
import transformers.modeling_outputs
import transformers.utils

from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.data.sample.token import TokenBatch
@@ -36,23 +38,23 @@ def inner_forward(
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
past_key_values=None,
past_key_values: transformers.cache_utils.Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
labels: torch.Tensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.Tensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: typing.Unpack[transformers.utils.TransformersKwargs],
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
return self._inner_forward(
self._get_batch(input_ids, attention_mask, position_ids),
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
logits_to_keep,
**kwargs,
)

def _get_batch(
Expand Down Expand Up @@ -82,20 +84,26 @@ def _get_batch(
def _inner_forward(
self,
batch: LanguageModelBatch,
past_key_values=None,
past_key_values: transformers.cache_utils.Cache | None = None,
[Review comment] jlamypoirier (Collaborator), Dec 2, 2025: "We don't currently follow the new cache format." (See the Cache API sketch after this file's diff.)

inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
labels: torch.Tensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: list[str | re.Pattern] | bool | None = None,
return_dict: bool | None = None,
cache_position: torch.Tensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: typing.Unpack[transformers.utils.TransformersKwargs],
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
# TODO: Most of this is generalizable.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
kwargs["output_attentions"]
if "output_attentions" in kwargs and kwargs["output_attentions"] is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
kwargs["output_hidden_states"]
if "output_hidden_states" in kwargs and kwargs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = kwargs["return_dict"] if "return_dict" in kwargs else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache

if output_attentions:
@@ -104,6 +112,12 @@ def _inner_forward(
raise NotImplementedError()
if labels is not None:
raise NotImplementedError()
# TODO: cache_position seems to always be provided even if use_cache is False;
# check whether that is the case and implement support for it.
# if cache_position is not None:
# raise NotImplementedError()
if isinstance(logits_to_keep, torch.Tensor) or logits_to_keep > 0:
raise NotImplementedError()

# Iteration serves as a random seed, using random module because it's not seeded by Fast LLM
iteration = random.randint(0, 2**32)
@@ -122,33 +136,33 @@ def _inner_forward(
# kwargs is shallow-copied so changes will propagate back to the main namespace.
kwargs_meta[BlockKwargs.output_hidden_states] = [re.compile(pattern) for pattern in output_hidden_states]

((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch(
((input_, batch_kwargs),) = self.fast_llm_base_model.preprocess_batch(
batch, [(input_meta, kwargs_meta)], phase=PhaseType.inference, iteration=iteration
)

if past_key_values is not None:
# The transformers-side past keys and values are passed in through this list.
kwargs[AttentionKwargs.past_key_values] = past_key_values
batch_kwargs[AttentionKwargs.past_key_values] = past_key_values
# TODO: preprocess needs to know about the past.
raise NotImplementedError()
if use_cache:
# The present keys and values will be saved to this list and returned to the transformers side.
kwargs[AttentionKwargs.presents] = []
batch_kwargs[AttentionKwargs.presents] = []

kwargs["global_logits"] = True
batch_kwargs["global_logits"] = True

self._inference_runner.forward(input_, kwargs, iteration=iteration)
self._inference_runner.forward(input_, batch_kwargs, iteration=iteration)

# TODO: Add a proper way of returning the model output.
if kwargs[AttentionKwargs.sequence_first]:
logits = kwargs["logits"].transpose(0, 1)
if batch_kwargs[AttentionKwargs.sequence_first]:
logits = batch_kwargs["logits"].transpose(0, 1)
else:
logits = kwargs["logits"]
logits = batch_kwargs["logits"]

if output_hidden_states:
hidden_states = {
key: tensor if meta is None else meta.local_to_global(tensor)[0]
for key, (meta, tensor) in kwargs["hidden_states"].items()
for key, (meta, tensor) in batch_kwargs["hidden_states"].items()
}
else:
hidden_states = None
@@ -167,5 +181,5 @@ def _inner_forward(
return transformers.modeling_outputs.CausalLMOutputWithPast(
logits=logits,
hidden_states=hidden_states,
past_key_values=kwargs[AttentionKwargs.presents],
past_key_values=batch_kwargs[AttentionKwargs.presents],
)
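On the review note above about the new cache format: recent transformers versions pass past_key_values as a Cache object (e.g. DynamicCache) instead of the legacy tuple-of-tuples, which is what the updated type hints accept but the Fast-LLM runner does not yet consume. A short, Fast-LLM-independent illustration of that API (shapes are arbitrary):

```python
# Illustration of the transformers Cache API referenced by the past_key_values type hints above.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Pretend one decoder layer produced keys/values shaped (batch, heads, seq_len, head_dim).
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)
cache.update(key, value, layer_idx=0)
print(cache.get_seq_length())  # 4 tokens cached so far
```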
18 changes: 10 additions & 8 deletions fast_llm/models/multimodal/huggingface.py
@@ -2,7 +2,9 @@
import typing

import torch
import transformers.cache_utils
import transformers.modeling_outputs
import transformers.utils

from fast_llm.data.preprocessing.image_patch import ImagePatchConfig
from fast_llm.data.sample.patch import PatchBatch
@@ -51,23 +53,23 @@ def inner_forward(
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
image_sizes: torch.Tensor | None = None,
past_key_values=None,
past_key_values: transformers.cache_utils.Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
labels: torch.Tensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
cache_position: torch.Tensor | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: typing.Unpack[transformers.utils.TransformersKwargs],
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
return self._inner_forward(
self._get_batch(input_ids, pixel_values, attention_mask, position_ids, image_sizes),
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
logits_to_keep,
**kwargs,
)

def _get_batch(
24 changes: 17 additions & 7 deletions tests/models/test_generate.py
@@ -10,6 +10,7 @@
from fast_llm.models.gpt.config import PretrainedGPTModelConfig
from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingGroup
from tests.utils.utils import requires_cuda

@@ -152,7 +153,9 @@ def _test_for_batches(
if tokenizer is not None:
inputs = _prepare_data(tokenizer, use_batch_size2=False)
else:
inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=False)
inputs = _prepare_rand_data(
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=False
)
outputs = _generate(
inputs,
hf_model,
@@ -164,7 +167,9 @@
if tokenizer is not None:
inputs = _prepare_data(tokenizer, use_batch_size2=True)
else:
inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=True)
inputs = _prepare_rand_data(
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size, use_batch_size2=True
)
outputs = _generate(
inputs,
hf_model,
@@ -244,13 +249,19 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi
# Not really testing anything, but handles dependencies more easily than a fixture.
if model_testing_config.checkpoint_format is None:
pytest.skip(f"Conversion not supported for {model_testing_config.name}")
run_test_script_for_all_models(
[
if torch.cuda.device_count() < 1:
pytest.skip(f"Not enough gpus to run the test")

distr_config = DistributedTestingConfig(
name="test_export_for_generate",
config_args=[
"training.train_iters=1",
f"training.export.format={model_testing_config.checkpoint_format.name}",
"training.export.interval=1",
],
num_gpus=1,
)
run_test_script_for_all_models(distr_config)


@pytest.mark.slow
Expand Down Expand Up @@ -340,7 +351,7 @@ def _test_forward_return_hidden_states(

inputs_ids = torch.randint(
1,
fast_llm_model.config.fast_llm_config.base_model.vocab_size if vocab_size is None else vocab_size,
fast_llm_model.config.fast_llm_config.base_model.embeddings.vocab_size if vocab_size is None else vocab_size,
[1, 10],
dtype=torch.int64,
generator=torch.Generator().manual_seed(42),
@@ -351,8 +362,7 @@
input_ids=inputs_ids, output_hidden_states=True, return_dict=True, use_cache=False
)

# hidden_states include embeddings layer
assert len(res_fast_llm.hidden_states) - 1 == len(fast_llm_model.config.fast_llm_config.base_model.decoder)
assert len(res_fast_llm.hidden_states) == fast_llm_model.config.fast_llm_config.base_model.decoder.num_blocks


@pytest.mark.extra_slow
4 changes: 2 additions & 2 deletions tests/models/test_lm_eval.py
@@ -34,8 +34,8 @@ def do_get_lm_eval_config(base_path):

task_dir = pathlib.Path(lm_eval.tasks.__file__).parent.resolve()
return [
f"data.tokenizer.path={tokenizer_path}",
f"model.base_model.vocab_size=49157",
f"training.evaluators.evaluation_test.evaluator.tokenizer.path={tokenizer_path}",
f"model.base_model.embeddings.vocab_size=49157",
"training.evaluators.evaluation_test.interval=2",
"training.evaluators.evaluation_test.evaluator.type=lm_eval",
"training.evaluators.evaluation_test.evaluator.cli_args="