Skip to content

Commit dad30f0

Browse files
Add ChatOptions to PromptTemplate create methods
Enable ChatOptions configuration when creating Prompts from Templates, Co-authored-by: zhangqian9158 <zhangqian9158@users.noreply.github.com>
1 parent 528dc04 commit dad30f0

File tree

4 files changed

+78
-0
lines changed

4 files changed

+78
-0
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatPromptTemplate.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,22 @@ public Prompt create() {
7575
return new Prompt(messages);
7676
}
7777

78+
@Override
79+
public Prompt create(ChatOptions modelOptions) {
80+
List<Message> messages = createMessages();
81+
return new Prompt(messages, modelOptions);
82+
}
83+
7884
@Override
7985
public Prompt create(Map<String, Object> model) {
8086
List<Message> messages = createMessages(model);
8187
return new Prompt(messages);
8288
}
8389

90+
@Override
91+
public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
92+
List<Message> messages = createMessages(model);
93+
return new Prompt(messages, modelOptions);
94+
}
95+
8496
}

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplate.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,21 @@ public Prompt create() {
172172
return new Prompt(render(new HashMap<>()));
173173
}
174174

175+
@Override
176+
public Prompt create(ChatOptions modelOptions) {
177+
return new Prompt(render(new HashMap<>()), modelOptions);
178+
}
179+
175180
@Override
176181
public Prompt create(Map<String, Object> model) {
177182
return new Prompt(render(model));
178183
}
179184

185+
@Override
186+
public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
187+
return new Prompt(render(model), modelOptions);
188+
}
189+
180190
public Set<String> getInputVariables() {
181191
TokenStream tokens = this.st.impl.tokens;
182192
Set<String> inputVariables = new HashSet<>();

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/PromptTemplateActions.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public interface PromptTemplateActions extends PromptTemplateStringActions {
2121

2222
Prompt create();
2323

24+
Prompt create(ChatOptions modelOptions);
25+
2426
Prompt create(Map<String, Object> model);
2527

28+
Prompt create(Map<String, Object> model, ChatOptions modelOptions);
29+
2630
}

spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import org.junit.jupiter.api.Disabled;
1919
import org.junit.jupiter.api.Test;
2020
import org.springframework.ai.chat.messages.Message;
21+
import org.springframework.ai.chat.prompt.ChatOptions;
22+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
23+
import org.springframework.ai.chat.prompt.Prompt;
2124
import org.springframework.ai.chat.prompt.PromptTemplate;
2225
import org.springframework.core.io.InputStreamResource;
2326
import org.springframework.core.io.Resource;
@@ -31,11 +34,60 @@
3134
import java.util.Map;
3235

3336
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
37+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
3438
import static org.junit.jupiter.api.Assertions.assertEquals;
3539
import static org.junit.jupiter.api.Assertions.assertThrows;
3640

3741
public class PromptTemplateTest {
3842

43+
@Test
44+
public void testCreateWithEmptyModelAndChatOptions() {
45+
String template = "This is a test prompt with no variables";
46+
PromptTemplate promptTemplate = new PromptTemplate(template);
47+
ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.7f).withTopK(3).build();
48+
49+
Prompt prompt = promptTemplate.create(chatOptions);
50+
51+
assertThat(prompt).isNotNull();
52+
assertThat(prompt.getContents()).isEqualTo(template);
53+
assertThat(prompt.getOptions()).isEqualTo(chatOptions);
54+
}
55+
56+
@Test
57+
public void testCreateWithModelAndChatOptions() {
58+
String template = "Hello, {name}! Your age is {age}.";
59+
Map<String, Object> model = new HashMap<>();
60+
model.put("name", "Alice");
61+
model.put("age", 30);
62+
PromptTemplate promptTemplate = new PromptTemplate(template, model);
63+
ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.5f).withMaxTokens(100).build();
64+
65+
Prompt prompt = promptTemplate.create(model, chatOptions);
66+
67+
assertThat(prompt).isNotNull();
68+
assertThat(prompt.getContents()).isEqualTo("Hello, Alice! Your age is 30.");
69+
assertThat(prompt.getOptions()).isEqualTo(chatOptions);
70+
}
71+
72+
@Test
73+
public void testCreateWithOverriddenModelAndChatOptions() {
74+
String template = "Hello, {name}! Your favorite color is {color}.";
75+
Map<String, Object> initialModel = new HashMap<>();
76+
initialModel.put("name", "Bob");
77+
initialModel.put("color", "blue");
78+
PromptTemplate promptTemplate = new PromptTemplate(template, initialModel);
79+
80+
Map<String, Object> overriddenModel = new HashMap<>();
81+
overriddenModel.put("color", "red");
82+
ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.8f).build();
83+
84+
Prompt prompt = promptTemplate.create(overriddenModel, chatOptions);
85+
86+
assertThat(prompt).isNotNull();
87+
assertThat(prompt.getContents()).isEqualTo("Hello, Bob! Your favorite color is red.");
88+
assertThat(prompt.getOptions()).isEqualTo(chatOptions);
89+
}
90+
3991
@Test
4092
public void testRenderWithList() {
4193
String templateString = "The items are:\n{items:{item | - {item}\n}}";

0 commit comments

Comments
 (0)