From f93708dfa31ca1791218812ac4069134527e6dd6 Mon Sep 17 00:00:00 2001
From: liugddx
Date: Thu, 6 Nov 2025 23:24:04 +0800
Subject: [PATCH 1/2] feat: implement error handling strategies for streaming
 chat completions

Signed-off-by: liugddx
---
 models/spring-ai-openai/pom.xml                    |   6 +
 .../api/ChatCompletionParseException.java          |  49 +++
 .../ai/openai/api/OpenAiApi.java                   |  76 ++++-
 .../api/StreamErrorHandlingStrategy.java           |  49 +++
 .../api/OpenAiStreamErrorHandlingTest.java         | 289 ++++++++++++++++++
 5 files changed, 463 insertions(+), 6 deletions(-)
 create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletionParseException.java
 create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/StreamErrorHandlingStrategy.java
 create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java

diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml
index 3e534615f2b..9de0f58d348 100644
--- a/models/spring-ai-openai/pom.xml
+++ b/models/spring-ai-openai/pom.xml
@@ -117,6 +117,12 @@
 			<scope>test</scope>
 		</dependency>
 
+		<dependency>
+			<groupId>io.projectreactor</groupId>
+			<artifactId>reactor-test</artifactId>
+			<scope>test</scope>
+		</dependency>
+
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletionParseException.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletionParseException.java
new file mode 100644
index 00000000000..a081229cc4e
--- /dev/null
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletionParseException.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.openai.api;
+
+/**
+ * Exception thrown when a ChatCompletionChunk cannot be parsed from a streaming
+ * response. This typically occurs when the LLM returns malformed JSON.
+ *
+ * @author Liu Guodong
+ * @since 1.0.0
+ */
+public class ChatCompletionParseException extends RuntimeException {
+
+	private final String rawContent;
+
+	/**
+	 * Constructs a new ChatCompletionParseException.
+	 * @param message the detail message
+	 * @param rawContent the raw content that failed to parse
+	 * @param cause the cause of the parsing failure
+	 */
+	public ChatCompletionParseException(String message, String rawContent, Throwable cause) {
+		super(message, cause);
+		this.rawContent = rawContent;
+	}
+
+	/**
+	 * Returns the raw content that failed to parse.
+	 * @return the raw content string
+	 */
+	public String getRawContent() {
+		return this.rawContent;
+	}
+
+}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 070d4e4b5c6..23b305f58a0 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -70,6 +70,8 @@
  */
 public class OpenAiApi {
 
+	private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(OpenAiApi.class);
+
 	public static final String HTTP_USER_AGENT_HEADER = "User-Agent";
 
 	public static final String SPRING_AI_USER_AGENT = "spring-ai";
 
@@ -116,6 +118,8 @@ public static Builder builder() {
 
 	private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
 
+	private StreamErrorHandlingStrategy streamErrorHandlingStrategy = StreamErrorHandlingStrategy.SKIP;
+
 	/**
 	 * Create a new chat completion api.
 	 * @param baseUrl api base URL.
@@ -245,7 +249,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
 			.headers(headers -> {
 				headers.addAll(additionalHttpHeader);
 				addDefaultHeadersIfMissing(headers);
-			}) // @formatter:on
+			})
 			.body(Mono.just(chatRequest), ChatCompletionRequest.class)
 			.retrieve()
 			.bodyToFlux(String.class)
 			// cancels the flux stream after the "[DONE]" is received.
 			.takeUntil(SSE_DONE_PREDICATE)
 			// filters out the "[DONE]" message.
 			.filter(SSE_DONE_PREDICATE.negate())
-			.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
-			// Detect is the chunk is part of a streaming function call.
+			// Parse JSON string to ChatCompletionChunk with error handling
+			.flatMap(content -> {
+				try {
+					ChatCompletionChunk chunk = ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class);
+					return Mono.just(chunk);
+				}
+				catch (Exception e) {
+					return handleParseError(content, e);
+				}
+			})
+			// Detect if the chunk is part of a streaming function call.
 			.map(chunk -> {
 				if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
 					isInsideTool.set(true);
 				}
@@ -276,12 +289,52 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
 			// Flux<Flux<ChatCompletionChunk>> -> Flux<Mono<ChatCompletionChunk>>
 			.concatMapIterable(window -> {
 				Mono<ChatCompletionChunk> monoChunk = window.reduce(
-						new ChatCompletionChunk(null, null, null, null, null, null, null, null),
 						(previous, current) -> this.chunkMerger.merge(previous, current));
 				return List.of(monoChunk);
 			})
 			// Flux<Mono<ChatCompletionChunk>> -> Flux<ChatCompletionChunk>
 			.flatMap(mono -> mono);
+		// @formatter:on
+	}
+
+	/**
+	 * Handles parsing errors when processing streaming chat completion chunks. The
+	 * behavior depends on the configured {@link StreamErrorHandlingStrategy}.
+	 * @param content the raw content that failed to parse
+	 * @param e the exception that occurred during parsing
+	 * @return a Mono that either emits nothing (skip), emits an error, or logs and
+	 * continues
+	 */
+	private Mono<ChatCompletionChunk> handleParseError(String content, Exception e) {
+		String errorMessage = String.format(
+				"Failed to parse ChatCompletionChunk from streaming response. "
+						+ "Raw content: [%s]. This may indicate malformed JSON from the LLM. "
+						+ "Error: %s",
+				content, e.getMessage());
+
+		switch (this.streamErrorHandlingStrategy) {
+			case FAIL_FAST:
+				logger.error(errorMessage, e);
+				return Mono.error(new ChatCompletionParseException("Invalid JSON chunk received from LLM", content, e));
+
+			case LOG_AND_CONTINUE:
+				logger.warn(errorMessage);
+				logger.debug("Full stack trace for JSON parsing error:", e);
+				return Mono.empty();
+
+			case SKIP:
+			default:
+				logger.warn("Skipping invalid chunk in streaming response. Raw content: [{}]. Error: {}", content,
+						e.getMessage());
+				return Mono.empty();
+		}
+	}
+
+	/**
+	 * Sets the error handling strategy for streaming chat completion parsing errors.
+	 * @param strategy the strategy to use when encountering JSON parsing errors
+	 */
+	public void setStreamErrorHandlingStrategy(StreamErrorHandlingStrategy strategy) {
+		this.streamErrorHandlingStrategy = strategy != null ? strategy : StreamErrorHandlingStrategy.SKIP;
 	}
 
 	/**
@@ -2006,6 +2059,7 @@ public Builder(OpenAiApi api) {
 			this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder();
 			this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder();
 			this.responseErrorHandler = api.getResponseErrorHandler();
+			this.streamErrorHandlingStrategy = api.streamErrorHandlingStrategy;
 		}
 
 		private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;
@@ -2024,6 +2078,8 @@ public Builder(OpenAiApi api) {
 
 		private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
 
+		private StreamErrorHandlingStrategy streamErrorHandlingStrategy = StreamErrorHandlingStrategy.SKIP;
+
 		public Builder baseUrl(String baseUrl) {
 			Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
 			this.baseUrl = baseUrl;
@@ -2077,10 +2133,18 @@ public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
 			return this;
 		}
 
+		public Builder streamErrorHandlingStrategy(StreamErrorHandlingStrategy streamErrorHandlingStrategy) {
+			Assert.notNull(streamErrorHandlingStrategy, "streamErrorHandlingStrategy cannot be null");
+			this.streamErrorHandlingStrategy = streamErrorHandlingStrategy;
+			return this;
+		}
+
 		public OpenAiApi build() {
 			Assert.notNull(this.apiKey, "apiKey must be set");
-			return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath,
-					this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
+			OpenAiApi api = new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath,
+					this.embeddingsPath, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
+			api.setStreamErrorHandlingStrategy(this.streamErrorHandlingStrategy);
+			return api;
 		}
 
 	}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/StreamErrorHandlingStrategy.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/StreamErrorHandlingStrategy.java
new file mode 100644
index 00000000000..3b15ea488b0
--- /dev/null
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/StreamErrorHandlingStrategy.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.openai.api;
+
+/**
+ * Strategy for handling JSON parsing errors in streaming chat completions. This is
+ * particularly useful when dealing with LLMs that may return malformed JSON, such as
+ * Qwen3-8B or other custom models.
+ *
+ * @author Liu Guodong
+ * @since 1.0.0
+ */
+public enum StreamErrorHandlingStrategy {
+
+	/**
+	 * Skip invalid chunks and continue processing the stream. This is the default and
+	 * recommended strategy for production use. Invalid chunks are logged but do not
+	 * interrupt the stream.
+	 */
+	SKIP,
+
+	/**
+	 * Fail immediately when encountering an invalid chunk. The error is propagated
+	 * through the reactive stream, terminating the stream processing.
+	 */
+	FAIL_FAST,
+
+	/**
+	 * Log the error and continue processing. Similar to SKIP but with more detailed
+	 * logging. Use this for debugging or when you want to monitor the frequency of
+	 * parsing errors.
+	 */
+	LOG_AND_CONTINUE
+
+}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java
new file mode 100644
index 00000000000..1a4971d47e4
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java
@@ -0,0 +1,289 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.openai.api;
+
+import java.util.List;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Flux;
+import reactor.test.StepVerifier;
+
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit tests for streaming chat completion error handling in {@link OpenAiApi}. Tests
+ * the behavior when the LLM returns malformed JSON chunks.
+ *
+ * @author Liu Guodong
+ */
+class OpenAiStreamErrorHandlingTest {
+
+	private MockWebServer mockWebServer;
+
+	private OpenAiApi openAiApi;
+
+	@BeforeEach
+	void setUp() throws Exception {
+		this.mockWebServer = new MockWebServer();
+		this.mockWebServer.start();
+	}
+
+	@AfterEach
+	void tearDown() throws Exception {
+		if (this.mockWebServer != null) {
+			this.mockWebServer.shutdown();
+		}
+	}
+
+	@Test
+	void testSkipStrategy_shouldSkipInvalidChunks() {
+		// Arrange
+		String validChunk1 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "invalid json {";
+		String validChunk2 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
+				""";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk1 + invalidChunk + validChunk2 + "[DONE]"));
+
+		this.openAiApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.SKIP)
+			.build();
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert - should receive 2 valid chunks, invalid one is skipped
+		StepVerifier.create(result)
+			.expectNextMatches(chunk -> chunk.choices() != null && chunk.choices().size() > 0)
+			.expectNextMatches(chunk -> chunk.choices() != null && chunk.choices().size() > 0)
+			.verifyComplete();
+	}
+
+	@Test
+	void testFailFastStrategy_shouldThrowErrorOnInvalidChunk() {
+		// Arrange
+		String validChunk = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "invalid json {";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk + invalidChunk + "[DONE]"));
+
+		this.openAiApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.FAIL_FAST)
+			.build();
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert - should receive 1 valid chunk then error
+		StepVerifier.create(result)
+			.expectNextMatches(chunk -> chunk.choices() != null && chunk.choices().size() > 0)
+			.expectError(ChatCompletionParseException.class)
+			.verify();
+	}
+
+	@Test
+	void testLogAndContinueStrategy_shouldLogAndSkipInvalidChunks() {
+		// Arrange
+		String validChunk1 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "{incomplete";
+		String validChunk2 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":"stop"}]}
+				""";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk1 + invalidChunk + validChunk2 + "[DONE]"));
+
+		this.openAiApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.LOG_AND_CONTINUE)
+			.build();
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert - should receive 2 valid chunks, invalid one is logged and skipped
+		StepVerifier.create(result).expectNextCount(2).verifyComplete();
+	}
+
+	@Test
+	void testDefaultStrategy_shouldBeSkip() {
+		// Arrange
+		this.openAiApi = OpenAiApi.builder().apiKey("test-key").baseUrl(this.mockWebServer.url("/").toString()).build();
+
+		String validChunk = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "not valid json";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk + invalidChunk + "[DONE]"));
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert - default strategy should skip invalid chunks
+		StepVerifier.create(result).expectNextCount(1).verifyComplete();
+	}
+
+	@Test
+	void testAllValidChunks_shouldProcessNormally() {
+		// Arrange
+		String validChunk1 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String validChunk2 = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":"stop"}]}
+				""";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk1 + validChunk2 + "[DONE]"));
+
+		this.openAiApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.SKIP)
+			.build();
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert
+		StepVerifier.create(result).expectNextCount(2).verifyComplete();
+	}
+
+	@Test
+	void testChatCompletionParseException_shouldContainRawContent() {
+		// Arrange
+		String rawContent = "invalid json content";
+		Exception cause = new RuntimeException("Parse error");
+
+		// Act
+		ChatCompletionParseException exception = new ChatCompletionParseException("Test error", rawContent, cause);
+
+		// Assert
+		assertThat(exception.getRawContent()).isEqualTo(rawContent);
+		assertThat(exception.getMessage()).isEqualTo("Test error");
+		assertThat(exception.getCause()).isEqualTo(cause);
+	}
+
+	@Test
+	void testMutateApi_shouldPreserveErrorHandlingStrategy() {
+		// Arrange
+		OpenAiApi originalApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.FAIL_FAST)
+			.build();
+
+		// Act
+		OpenAiApi mutatedApi = originalApi.mutate().apiKey("new-key").build();
+
+		String validChunk = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "invalid";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk + invalidChunk + "[DONE]"));
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = mutatedApi.chatCompletionStream(request);
+
+		// Assert - should still use FAIL_FAST strategy
+		StepVerifier.create(result).expectNextCount(1).expectError(ChatCompletionParseException.class).verify();
+	}
+
+	@Test
+	void testSetStreamErrorHandlingStrategy_shouldUpdateStrategy() {
+		// Arrange
+		this.openAiApi = OpenAiApi.builder()
+			.apiKey("test-key")
+			.baseUrl(this.mockWebServer.url("/").toString())
+			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.SKIP)
+			.build();
+
+		// Change strategy to FAIL_FAST
+		this.openAiApi.setStreamErrorHandlingStrategy(StreamErrorHandlingStrategy.FAIL_FAST);
+
+		String validChunk = """
+				{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+				""";
+		String invalidChunk = "bad json";
+
+		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
+			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
+			.setBody(validChunk + invalidChunk + "[DONE]"));
+
+		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
+
+		// Act
+		Flux<ChatCompletionChunk> result = this.openAiApi.chatCompletionStream(request);
+
+		// Assert - should now fail fast
+		StepVerifier.create(result).expectNextCount(1).expectError(ChatCompletionParseException.class).verify();
+	}
+
+}

From 32b536395952f930302d766752987e57002ed737 Mon Sep 17 00:00:00 2001
From: liugddx
Date: Fri, 7 Nov 2025 22:13:53 +0800
Subject: [PATCH 2/2] feat: enhance OpenAiApi to handle multi-line responses
 and improve test cases

Signed-off-by: liugddx
---
 .../springframework/ai/openai/api/OpenAiApi.java   |  4 ++++
 .../openai/api/OpenAiStreamErrorHandlingTest.java  | 14 +++++++-------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 23b305f58a0..1fe5570deb6 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -253,6 +253,10 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
 			.body(Mono.just(chatRequest), ChatCompletionRequest.class)
 			.retrieve()
 			.bodyToFlux(String.class)
+			// Split by newlines to handle multi-line responses (common in tests with MockWebServer)
+			.flatMap(content -> Flux.fromArray(content.split("\\r?\\n")))
+			// Filter out empty lines
+			.filter(line -> !line.trim().isEmpty())
 			// cancels the flux stream after the "[DONE]" is received.
 			.takeUntil(SSE_DONE_PREDICATE)
 			// filters out the "[DONE]" message.
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java
index 1a4971d47e4..34a95ac0167 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamErrorHandlingTest.java
@@ -72,7 +72,7 @@ void testSkipStrategy_shouldSkipInvalidChunks() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk1 + invalidChunk + validChunk2 + "[DONE]"));
+			.setBody(validChunk1.trim() + "\n" + invalidChunk + "\n" + validChunk2.trim() + "\n[DONE]"));
 
 		this.openAiApi = OpenAiApi.builder()
 			.apiKey("test-key")
@@ -103,7 +103,7 @@ void testFailFastStrategy_shouldThrowErrorOnInvalidChunk() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk + invalidChunk + "[DONE]"));
+			.setBody(validChunk.trim() + "\n" + invalidChunk + "\n[DONE]"));
 
 		this.openAiApi = OpenAiApi.builder()
 			.apiKey("test-key")
@@ -137,7 +137,7 @@ void testLogAndContinueStrategy_shouldLogAndSkipInvalidChunks() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk1 + invalidChunk + validChunk2 + "[DONE]"));
+			.setBody(validChunk1.trim() + "\n" + invalidChunk + "\n" + validChunk2.trim() + "\n[DONE]"));
 
 		this.openAiApi = OpenAiApi.builder()
 			.apiKey("test-key")
@@ -167,7 +167,7 @@ void testDefaultStrategy_shouldBeSkip() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk + invalidChunk + "[DONE]"));
+			.setBody(validChunk.trim() + "\n" + invalidChunk + "\n[DONE]"));
 
 		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
 		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
@@ -191,7 +191,7 @@ void testAllValidChunks_shouldProcessNormally() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk1 + validChunk2 + "[DONE]"));
+			.setBody(validChunk1.trim() + "\n" + validChunk2.trim() + "\n[DONE]"));
 
 		this.openAiApi = OpenAiApi.builder()
 			.apiKey("test-key")
@@ -243,7 +243,7 @@ void testMutateApi_shouldPreserveErrorHandlingStrategy() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk + invalidChunk + "[DONE]"));
+			.setBody(validChunk.trim() + "\n" + invalidChunk + "\n[DONE]"));
 
 		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
 		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
@@ -274,7 +274,7 @@ void testSetStreamErrorHandlingStrategy_shouldUpdateStrategy() {
 
 		this.mockWebServer.enqueue(new MockResponse().setResponseCode(200)
 			.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE)
-			.setBody(validChunk + invalidChunk + "[DONE]"));
+			.setBody(validChunk.trim() + "\n" + invalidChunk + "\n[DONE]"));
 
 		ChatCompletionMessage message = new ChatCompletionMessage("Test", ChatCompletionMessage.Role.USER);
 		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);
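
Usage sketch (illustrative, not part of the patch): how a client might opt into the new
strategy through the builder added in PATCH 1/2. The builder method, enum values, setter,
exception, and chatCompletionStream signature come from the patch itself; the API-key
lookup, model name, and the subscriber's print handling are placeholder assumptions.

import java.util.List;

import reactor.core.publisher.Flux;

import org.springframework.ai.openai.api.ChatCompletionParseException;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.StreamErrorHandlingStrategy;

public class StreamErrorHandlingUsage {

	public static void main(String[] args) {
		// LOG_AND_CONTINUE: malformed chunks are logged and dropped, the stream survives.
		// Key source is a placeholder; any existing apiKey(String) value works the same way.
		OpenAiApi api = OpenAiApi.builder()
			.apiKey(System.getenv("OPENAI_API_KEY"))
			.streamErrorHandlingStrategy(StreamErrorHandlingStrategy.LOG_AND_CONTINUE)
			.build();

		ChatCompletionMessage message = new ChatCompletionMessage("Hello", ChatCompletionMessage.Role.USER);
		ChatCompletionRequest request = new ChatCompletionRequest(List.of(message), "gpt-3.5-turbo", 0.8, true);

		Flux<ChatCompletionChunk> chunks = api.chatCompletionStream(request);

		// Under FAIL_FAST (e.g. after api.setStreamErrorHandlingStrategy(
		// StreamErrorHandlingStrategy.FAIL_FAST)), the error consumer below would
		// receive a ChatCompletionParseException whose getRawContent() carries the
		// offending chunk; under SKIP and LOG_AND_CONTINUE it is never invoked for
		// parse failures.
		chunks.subscribe(
				chunk -> System.out.println(chunk.choices()),
				error -> {
					if (error instanceof ChatCompletionParseException parseError) {
						System.err.println("Unparseable chunk: " + parseError.getRawContent());
					}
					else {
						System.err.println("Stream failed: " + error.getMessage());
					}
				});
	}

}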