@@ -2268,6 +2268,38 @@ def synchronize_input_prep(self):
22682268 finally :
22692269 self .prepare_inputs_event .record ()
22702270
2271+ def _model_forward (
2272+ self ,
2273+ input_ids : Optional [torch .Tensor ] = None ,
2274+ positions : Optional [torch .Tensor ] = None ,
2275+ intermediate_tensors : Optional [IntermediateTensors ] = None ,
2276+ inputs_embeds : Optional [torch .Tensor ] = None ,
2277+ ** model_kwargs : dict [str , Any ],
2278+ ) -> Any :
2279+ """Helper method to call the model forward pass.
2280+
2281+ This method can be overridden by subclasses for model execution.
2282+ Motivation: We can inspect only this method versus
2283+ the whole execute_model, which has additional logic.
2284+
2285+ Args:
2286+ input_ids: Input token IDs
2287+ positions: Token positions
2288+ intermediate_tensors: Tensors from previous pipeline stages
2289+ inputs_embeds: Input embeddings (alternative to input_ids)
2290+ **model_kwargs: Additional model arguments
2291+
2292+ Returns:
2293+ Model output tensor
2294+ """
2295+ return self .model (
2296+ input_ids = input_ids ,
2297+ positions = positions ,
2298+ intermediate_tensors = intermediate_tensors ,
2299+ inputs_embeds = inputs_embeds ,
2300+ ** model_kwargs ,
2301+ )
2302+
22712303 @torch .inference_mode ()
22722304 def execute_model (
22732305 self ,
@@ -2337,7 +2369,7 @@ def execute_model(
23372369 ), record_function_or_nullcontext ("Forward" ),
23382370 self .maybe_get_kv_connector_output (scheduler_output ) as
23392371 kv_connector_output ):
2340- model_output = self .model (
2372+ model_output = self ._model_forward (
23412373 input_ids = input_ids ,
23422374 positions = positions ,
23432375 intermediate_tensors = intermediate_tensors ,
0 commit comments