Skip to content

Commit fe6b19c

Browse files
[Bugfix] Properly abort pooling request. (vllm-project#25734)
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent 2827b3f commit fe6b19c

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

tests/v1/engine/test_output_processor.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
STOP_STRINGS,
1313
DummyOutputProcessorTestVectors,
1414
MockEngineCore)
15+
from vllm import PoolingParams
1516
from vllm.logprobs import PromptLogprobs, SampleLogprobs
1617
from vllm.outputs import CompletionOutput, RequestOutput
1718
from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n():
998999
third = [k for k in result.outputs if k.index == 2]
9991000
assert len(third) == 1
10001001
assert third[0].text == "c"
1002+
1003+
1004+
@pytest.mark.parametrize("runner", ["generate", "pooling"])
1005+
def test_abort_requests(runner: str, dummy_test_vectors):
1006+
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
1007+
log_stats=True)
1008+
requests = [
1009+
EngineCoreRequest(
1010+
request_id=f"request-{idx}",
1011+
prompt_token_ids=prompt_tokens,
1012+
mm_features=None,
1013+
eos_token_id=None,
1014+
arrival_time=0,
1015+
lora_request=None,
1016+
cache_salt=None,
1017+
data_parallel_rank=None,
1018+
sampling_params=SamplingParams() if runner == "generate" else None,
1019+
pooling_params=PoolingParams(
1020+
task="embed") if runner == "pooling" else None,
1021+
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
1022+
]
1023+
1024+
for request in requests:
1025+
if runner == "generate":
1026+
output_kind = request.sampling_params.output_kind
1027+
else:
1028+
output_kind = request.pooling_params.output_kind
1029+
queue = RequestOutputCollector(output_kind=output_kind)
1030+
output_processor.add_request(request, None, queue=queue)
1031+
1032+
for request in requests:
1033+
output_processor.abort_requests([request.request_id])

vllm/v1/engine/output_processor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,14 @@ def abort_requests(
335335
# Produce final abort output.
336336
if req_state.queue is not None and (
337337
request_output := req_state.make_request_output(
338-
[], None, FinishReason.ABORT, None, None)):
338+
new_token_ids=[],
339+
# Set pooling_output to a non-None value to
340+
# correctly enter the abort pooling branch
341+
pooling_output=torch.randn(0, device="cpu")
342+
if req_state.detokenizer is None else None,
343+
finish_reason=FinishReason.ABORT,
344+
stop_reason=None,
345+
kv_transfer_params=None)):
339346
req_state.queue.put(request_output)
340347
elif parent := self.parent_requests.get(request_id):
341348
# Abort children prior to removing the parent.

0 commit comments

Comments (0)