From d8fdd026bf1b18fb6fa6305b180dfbea97020e3c Mon Sep 17 00:00:00 2001 From: KimMinjeong Date: Thu, 8 Aug 2024 01:00:48 +0900 Subject: [PATCH 1/2] feat: add responseMimeType option in vertexAiGeminiChatOptions --- .../gemini/VertexAiGeminiChatModel.java | 3 ++ .../gemini/VertexAiGeminiChatOptions.java | 34 +++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index e02c4ad3427..575da73da90 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -359,6 +359,9 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) { if (options.getStopSequences() != null) { generationConfigBuilder.addAllStopSequences(options.getStopSequences()); } + if (options.getResponseMimeType() != null) { + generationConfigBuilder.setResponseMimeType(options.getResponseMimeType()); + } return generationConfigBuilder.build(); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index ec42584824a..9bb2c9cfb83 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -76,6 +76,12 @@ public enum TransportType { * Gemini model name. */ private @JsonProperty("modelName") String model; + /** + * Optional. Output response mimetype of the generated candidate text. + * - text/plain: (default) Text output. + * - application/json: JSON response in the candidates. + */ + private @JsonProperty("responseMimeType") String responseMimeType; /** * Tool Function Callbacks to register with the ChatModel. @@ -150,6 +156,12 @@ public Builder withModel(ChatModel model) { return this; } + public Builder withResponseMimeType(String mimeType){ + Assert.notNull(mimeType, "mimeType must not be null"); + this.options.setResponseMimeType(mimeType); + return this; + } + public Builder withFunctionCallbacks(List functionCallbacks) { this.options.functionCallbacks = functionCallbacks; return this; @@ -238,6 +250,14 @@ public void setModel(String modelName) { this.model = modelName; } + public String getResponseMimeType() { + return this.responseMimeType; + } + + public String setResponseMimeType(String mimeType) { + return this.responseMimeType = mimeType; + } + public List getFunctionCallbacks() { return this.functionCallbacks; } @@ -265,6 +285,7 @@ public int hashCode() { result = prime * result + ((candidateCount == null) ? 0 : candidateCount.hashCode()); result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode()); result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((responseMimeType == null) ? 0 : responseMimeType.hashCode()); result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode()); result = prime * result + ((functions == null) ? 0 : functions.hashCode()); return result; @@ -321,6 +342,13 @@ else if (!maxOutputTokens.equals(other.maxOutputTokens)) } else if (!model.equals(other.model)) return false; + if (responseMimeType == null) { + if (other.responseMimeType != null) + return false; + } + else if (!responseMimeType.equals(other.responseMimeType)) { + return false; + } if (functionCallbacks == null) { if (other.functionCallbacks != null) return false; @@ -340,8 +368,9 @@ else if (!functions.equals(other.functions)) public String toString() { return "VertexAiGeminiChatOptions [stopSequences=" + stopSequences + ", temperature=" + temperature + ", topP=" + topP + ", topK=" + topK + ", candidateCount=" + candidateCount + ", maxOutputTokens=" - + maxOutputTokens + ", model=" + model + ", functionCallbacks=" + functionCallbacks + ", functions=" - + functions + ", getClass()=" + getClass() + ", getStopSequences()=" + getStopSequences() + + maxOutputTokens + ", model=" + model + ", responseMimeType=" + responseMimeType + + ", functionCallbacks=" + functionCallbacks + ", functions=" + functions + + ", getClass()=" + getClass() + ", getStopSequences()=" + getStopSequences() + ", getTemperature()=" + getTemperature() + ", getTopP()=" + getTopP() + ", getTopK()=" + getTopK() + ", getCandidateCount()=" + getCandidateCount() + ", getMaxOutputTokens()=" + getMaxOutputTokens() + ", getModel()=" + getModel() + ", getFunctionCallbacks()=" + getFunctionCallbacks() @@ -364,6 +393,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); options.setModel(fromOptions.getModel()); options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); + options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setFunctions(fromOptions.getFunctions()); return options; } From b7e3db10b6bd48a8908c6b3aa0451e79a59e3e14 Mon Sep 17 00:00:00 2001 From: KimMinjeong Date: Thu, 8 Aug 2024 01:23:23 +0900 Subject: [PATCH 2/2] feat: add generationConfig test for the VertexAiGeminiChatMoel.toGenerationConfig method --- .../gemini/CreateGeminiRequestTests.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index 9deca852b67..3b5cc1a5721 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -191,4 +191,35 @@ public void defaultOptionsTools() { .isEqualTo("Overridden function description"); } + @Test + public void createRequestWithGenerationConfigOptions() { + + var client = new VertexAiGeminiChatModel(vertexAI, + VertexAiGeminiChatOptions.builder() + .withModel("DEFAULT_MODEL") + .withTemperature(66.6f) + .withMaxOutputTokens(100) + .withTopK(10.0f) + .withTopP(5.0f) + .withStopSequences(List.of("stop1", "stop2")) + .withCandidateCount(1) + .withResponseMimeType("application/json") + .build()); + + GeminiRequest request = client.createGeminiRequest(new Prompt("Test message content")); + + assertThat(request.contents()).hasSize(1); + + assertThat(request.model().getSystemInstruction()).isNotPresent(); + assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f); + assertThat(request.model().getGenerationConfig().getMaxOutputTokens()).isEqualTo(100); + assertThat(request.model().getGenerationConfig().getTopK()).isEqualTo(10.0f); + assertThat(request.model().getGenerationConfig().getTopP()).isEqualTo(5.0f); + assertThat(request.model().getGenerationConfig().getCandidateCount()).isEqualTo(1); + assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1"); + assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2"); + assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json"); + } + }