|
12 | 12 | STOP_STRINGS, |
13 | 13 | DummyOutputProcessorTestVectors, |
14 | 14 | MockEngineCore) |
| 15 | +from vllm import PoolingParams |
15 | 16 | from vllm.logprobs import PromptLogprobs, SampleLogprobs |
16 | 17 | from vllm.outputs import CompletionOutput, RequestOutput |
17 | 18 | from vllm.sampling_params import RequestOutputKind, SamplingParams |
@@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n(): |
998 | 999 | third = [k for k in result.outputs if k.index == 2] |
999 | 1000 | assert len(third) == 1 |
1000 | 1001 | assert third[0].text == "c" |
| 1002 | + |
| 1003 | + |
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
    """Abort in-flight requests for both generate and pooling runners.

    Builds one EngineCoreRequest per dummy prompt, registers each with the
    OutputProcessor, then aborts them all one by one. Verifies the processor
    tracks every request after add_request and tracks none after the aborts,
    so a silently no-op abort_requests would fail this test.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
                                       log_stats=True)
    requests = [
        EngineCoreRequest(
            request_id=f"request-{idx}",
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            eos_token_id=None,
            arrival_time=0,
            lora_request=None,
            cache_salt=None,
            data_parallel_rank=None,
            # Exactly one of sampling/pooling params is set, matching the
            # runner mode under test.
            sampling_params=SamplingParams() if runner == "generate" else None,
            pooling_params=PoolingParams(
                task="embed") if runner == "pooling" else None,
        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
    ]

    for request in requests:
        # The output kind lives on whichever params object this runner uses.
        if runner == "generate":
            output_kind = request.sampling_params.output_kind
        else:
            output_kind = request.pooling_params.output_kind
        queue = RequestOutputCollector(output_kind=output_kind)
        output_processor.add_request(request, None, queue=queue)

    # Every request must be tracked before aborting...
    assert output_processor.get_num_unfinished_requests() == len(requests)

    for request in requests:
        output_processor.abort_requests([request.request_id])

    # ...and none afterwards.
    assert output_processor.get_num_unfinished_requests() == 0
0 commit comments