Skip to content

Commit b65e56b

Browse files
[Core] Refactor self.model() to call a helper for subclassing. (vllm-project#25084)
Signed-off-by: Patrick Toulme <ptoulme@meta.com> Signed-off-by: Patrick Toulme <pctoulme+1@gmail.com>
1 parent 49996cd commit b65e56b

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,38 @@ def synchronize_input_prep(self):
22682268
finally:
22692269
self.prepare_inputs_event.record()
22702270

2271+
def _model_forward(
2272+
self,
2273+
input_ids: Optional[torch.Tensor] = None,
2274+
positions: Optional[torch.Tensor] = None,
2275+
intermediate_tensors: Optional[IntermediateTensors] = None,
2276+
inputs_embeds: Optional[torch.Tensor] = None,
2277+
**model_kwargs: dict[str, Any],
2278+
) -> Any:
2279+
"""Helper method to call the model forward pass.
2280+
2281+
This method can be overridden by subclasses for model execution.
2282+
Motivation: We can inspect only this method versus
2283+
the whole execute_model, which has additional logic.
2284+
2285+
Args:
2286+
input_ids: Input token IDs
2287+
positions: Token positions
2288+
intermediate_tensors: Tensors from previous pipeline stages
2289+
inputs_embeds: Input embeddings (alternative to input_ids)
2290+
**model_kwargs: Additional model arguments
2291+
2292+
Returns:
2293+
Model output tensor
2294+
"""
2295+
return self.model(
2296+
input_ids=input_ids,
2297+
positions=positions,
2298+
intermediate_tensors=intermediate_tensors,
2299+
inputs_embeds=inputs_embeds,
2300+
**model_kwargs,
2301+
)
2302+
22712303
@torch.inference_mode()
22722304
def execute_model(
22732305
self,
@@ -2337,7 +2369,7 @@ def execute_model(
23372369
), record_function_or_nullcontext("Forward"),
23382370
self.maybe_get_kv_connector_output(scheduler_output) as
23392371
kv_connector_output):
2340-
model_output = self.model(
2372+
model_output = self._model_forward(
23412373
input_ids=input_ids,
23422374
positions=positions,
23432375
intermediate_tensors=intermediate_tensors,

0 commit comments

Comments
 (0)