Commit c1c005e

revert tracing context manager changes to core.py
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent: 363aaee

1 file changed:
vllm/v1/engine/core.py

Lines changed: 44 additions & 73 deletions
@@ -63,7 +63,6 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
-from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
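
For context on the helper whose import is removed above: a minimal sketch of what a record_function_or_nullcontext-style wrapper typically looks like, assuming it just switches between a torch profiler range and a no-op context manager. The env-var check and the body are illustrative, not the actual vllm.v1.utils implementation.

# Illustrative sketch only -- not the actual vllm.v1.utils implementation.
import contextlib
import os

import torch


def record_function_or_nullcontext(name: str):
    # Hypothetical toggle: emit a named profiler range only when profiling
    # output is configured; otherwise return a no-op context manager so the
    # call sites can stay unconditional.
    if os.getenv("VLLM_TORCH_PROFILER_DIR"):
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()


# Usage as in the removed call sites:
# with record_function_or_nullcontext("core step: schedule"):
#     scheduler_output = scheduler.schedule()
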
@@ -322,21 +321,17 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        with record_function_or_nullcontext("core step: schedule"):
-            scheduler_output = self.scheduler.schedule()
-
-        with record_function_or_nullcontext("core step: execute_model"):
-            future = self.model_executor.execute_model(scheduler_output, non_block=True)
-            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-            if model_output is None:
-                model_output = self.model_executor.sample_tokens(grammar_output)
-
-        with record_function_or_nullcontext("core step: update_from_output"):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+        scheduler_output = self.scheduler.schedule()
+        future = self.model_executor.execute_model(scheduler_output, non_block=True)
+        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+        if model_output is None:
+            model_output = self.model_executor.sample_tokens(grammar_output)
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )

         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
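
The reverted step() keeps the same overlap as before: execute_model(..., non_block=True) returns a Future, the grammar bitmask is built on the CPU while the model runs, and result() block-waits afterwards. A self-contained sketch of that shape, using stand-ins rather than the real vLLM executor and scheduler classes:

# Simplified sketch of the non_block overlap in step(); everything here is a
# stand-in, only the ordering (submit, do CPU work, then block on result) matters.
import time
from concurrent.futures import ThreadPoolExecutor


def fake_gpu_forward() -> str:
    time.sleep(0.05)  # pretend asynchronous model execution
    return "model_output"


def fake_grammar_bitmask() -> str:
    return "bitmask"  # CPU work that overlaps with the in-flight forward pass


with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(fake_gpu_forward)    # like execute_model(..., non_block=True)
    grammar_output = fake_grammar_bitmask()   # runs while the future is still pending
    model_output = future.result()            # block-wait once the CPU work is done
    print(model_output, grammar_output)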

@@ -374,49 +369,32 @@ def step_with_batch_queue(
         model_executed = False
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
-            with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
-                scheduler_output = self.scheduler.schedule()
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: execute_model"
-            ):
-                exec_future = self.model_executor.execute_model(
-                    scheduler_output, non_block=True
-                )
+            scheduler_output = self.scheduler.schedule()
+            exec_future = self.model_executor.execute_model(
+                scheduler_output, non_block=True
+            )
             model_executed = scheduler_output.total_num_scheduled_tokens > 0

             if scheduler_output.pending_structured_output_tokens:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: pending_structured_output_tokens"
-                ):
-                    # We need to defer sampling until we have processed the model output
-                    # from the prior step.
-                    deferred_scheduler_output = scheduler_output
-                    # Block-wait for execute to return
-                    # (continues running async on the GPU).
-                    with self.log_error_detail(scheduler_output):
-                        exec_result = exec_future.result()
-                    assert exec_result is None
+                # We need to defer sampling until we have processed the model output
+                # from the prior step.
+                deferred_scheduler_output = scheduler_output
+                # Block-wait for execute to return (continues running async on the GPU).
+                with self.log_error_detail(scheduler_output):
+                    exec_result = exec_future.result()
+                assert exec_result is None
             else:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: get_grammar_bitmask"
-                ):
-                    # We aren't waiting for any tokens, get any grammar
-                    # output immediately.
-                    grammar_output = self.scheduler.get_grammar_bitmask(
-                        scheduler_output
-                    )
+                # We aren't waiting for any tokens, get any grammar output immediately.
+                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
                 # Block-wait for execute to return (continues running async on the GPU).
                 with self.log_error_detail(scheduler_output):
                     exec_result = exec_future.result()

                 if exec_result is None:
-                    with record_function_or_nullcontext(
-                        "core step_with_batch_queue: sample_tokens"
-                    ):
-                        # Call sample tokens.
-                        future = self.model_executor.sample_tokens(
-                            grammar_output, non_block=True
-                        )
+                    # Call sample tokens.
+                    future = self.model_executor.sample_tokens(
+                        grammar_output, non_block=True
+                    )
                 else:
                     # No sampling required (e.g. all requests finished).
                     future = cast(Future[ModelRunnerOutput], exec_future)
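
The branch above chooses between deferring sampling (the grammar bitmask depends on tokens from the previous step) and sampling immediately. A compressed sketch of that control flow; the strings, helpers, and Future plumbing below are hypothetical, only the branching mirrors the diff.

# Control-flow sketch only; not the real scheduler/executor interfaces.
from concurrent.futures import Future


def schedule_sampling(pending_structured_output_tokens: bool,
                      exec_future: Future) -> tuple[Future | None, str | None]:
    """Return (sample_future, deferred_batch)."""
    if pending_structured_output_tokens:
        # Bitmask needs tokens from the previous step: remember the batch,
        # block-wait for execute_model, and sample later.
        assert exec_future.result() is None
        return None, "deferred scheduler_output"
    # No dependency: build the bitmask now and start sampling right away.
    grammar_output = "bitmask"                   # stand-in for get_grammar_bitmask
    exec_result = exec_future.result()
    if exec_result is None:
        sample_future: Future = Future()
        sample_future.set_result(f"sampled({grammar_output})")
    else:
        sample_future = exec_future              # no sampling required
    return sample_future, None
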
@@ -436,34 +414,27 @@ def step_with_batch_queue(
             # only be called when the scheduler contains requests or the queue
             # is non-empty.
             return None, False
-        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
-            # Block until the next result is available.
-            future, scheduler_output = batch_queue.pop()
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-        with record_function_or_nullcontext(
-            "core step_with_batch_queue: update_from_output"
-        ):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )

         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: deferred_scheduler_output"
-            ):
-                # We now have the tokens needed to compute the bitmask for the
-                # deferred request. Get the bitmask and call sample tokens.
-                grammar_output = self.scheduler.get_grammar_bitmask(
-                    deferred_scheduler_output
-                )
-                future = self.model_executor.sample_tokens(
-                    grammar_output, non_block=True
-                )
-                batch_queue.appendleft((future, deferred_scheduler_output))
+            # We now have the tokens needed to compute the bitmask for the
+            # deferred request. Get the bitmask and call sample tokens.
+            grammar_output = self.scheduler.get_grammar_bitmask(
+                deferred_scheduler_output
+            )
+            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+            batch_queue.appendleft((future, deferred_scheduler_output))

         return engine_core_outputs, model_executed
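
The queue discipline this hunk relies on: new (future, scheduler_output) pairs are appendleft'ed and the oldest pair is pop()'ed from the right, so outputs are consumed in submission order and a deferred batch re-enters as the newest entry. A tiny runnable sketch with a plain deque and pre-resolved futures as stand-ins for the engine's types:

# Stand-in types only; illustrates appendleft/pop FIFO ordering, nothing else.
from collections import deque
from concurrent.futures import Future


def done(value: str) -> Future:
    f: Future = Future()
    f.set_result(value)
    return f


batch_queue: deque[tuple[Future, str]] = deque()
batch_queue.appendleft((done("out-1"), "batch-1"))   # scheduled first
batch_queue.appendleft((done("out-2"), "batch-2"))   # scheduled second

future, scheduler_output = batch_queue.pop()         # oldest batch comes out first
print(scheduler_output, future.result())             # -> batch-1 out-1

# A deferred batch whose sampling was just launched goes back in on the left,
# i.e. as the newest entry, so it is consumed after everything already queued.
batch_queue.appendleft((done("deferred-out"), "deferred-batch"))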
