Commit b5ca950

add api
1 parent f5cad0c commit b5ca950

File tree

1 file changed: +351 -0 lines changed


ChatBot.java

Lines changed: 351 additions & 0 deletions
@@ -0,0 +1,351 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.annotations.SerializedName;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

/**
 * The ChatBot class wraps the OpenAI API and lets you send messages and
 * receive responses. For more information on how this works, check out
 * the <a href="https://platform.openai.com/docs/api-reference/completions">OpenAI Documentation</a>.
 */
public class ChatBot {

    private final OkHttpClient client;
    private final MediaType mediaType;
    private final Gson gson;
    private final String apiKey;

    /**
     * Constructor requires your private API key.
     *
     * @param apiKey Your OpenAI API key that starts with "sk-".
     */
    public ChatBot(String apiKey) {
        this.apiKey = apiKey;
        this.client = new OkHttpClient.Builder()
                .connectTimeout(30, TimeUnit.SECONDS)
                .readTimeout(30, TimeUnit.SECONDS).build();
        this.mediaType = MediaType.get("application/json; charset=utf-8");
        this.gson = new GsonBuilder().create();
    }

    /**
     * Blocks the current thread until OpenAI responds to the HTTPS request.
     * The returned value includes the token usage, the generated text, and
     * the finish reason. You can access the generated message through
     * {@link ChatCompletionResponse#getChoices()}.
     *
     * @param request The input information for ChatGPT.
     * @return The returned response.
     * @throws IOException If an IO exception occurs.
     * @throws IllegalArgumentException If the input arguments are invalid.
     */
    public ChatCompletionResponse generateResponse(ChatCompletionRequest request) throws IOException {
        String json = this.gson.toJson(request);
        RequestBody body = RequestBody.create(json, this.mediaType);
        Request httpRequest = new Request.Builder()
                .url("https://api.openai.com/v1/chat/completions")
                .addHeader("Content-Type", "application/json")
                .addHeader("Authorization", "Bearer " + this.apiKey)
                .post(body).build();

        // Save the JsonObject so it can be logged if an error occurs
        JsonObject rootObject = null;
        try (Response response = this.client.newCall(httpRequest).execute()) {

            // Servers respond to API calls with json blocks. Since raw JSON isn't
            // very developer friendly, we wrap it for easy data access.
            rootObject = JsonParser.parseString(response.body().string()).getAsJsonObject();
            if (rootObject.has("error"))
                throw new IllegalArgumentException(rootObject.get("error").getAsJsonObject().get("message").getAsString());

            return new ChatCompletionResponse(rootObject);
        } catch (Throwable ex) {
            System.err.println("Some error occurred whilst using the Chat Completion API");
            System.err.println("Request:\n\n" + json);
            System.err.println("\nRoot Object:\n\n" + rootObject);
            throw ex;
        }
    }

    /**
     * The ChatGPT API takes a list of messages, each with a 'role' and
     * 'content'. The role is one of three options: system, assistant, or
     * user. 'System' is used to prompt ChatGPT before the user gives input.
     * 'Assistant' is a message from ChatGPT. 'User' is a message from the
     * human.
     */
    public static final class ChatMessage {

        private final String role;
        private final String content;

        /**
         * Constructor requires who sent the message, and the content of the
         * message.
         *
         * @param role Who sent the message.
         * @param content The raw content of the message.
         */
        public ChatMessage(String role, String content) {
            this.role = role;
            this.content = content;
        }

        /**
         * Constructs a message from the JSON object returned by the API.
         */
        public ChatMessage(JsonObject json) {
            this(json.get("role").getAsString(), json.get("content").getAsString());
        }

        public String getRole() {
            return this.role;
        }

        public String getContent() {
            return this.content;
        }
    }
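
    // Example (illustrative sketch only): a conversation is typically built
    // oldest-first, starting with an optional "system" prompt that primes the
    // model, followed by the "user" and "assistant" messages exchanged so far:
    //
    //   List<ChatMessage> messages = new ArrayList<>();
    //   messages.add(new ChatMessage("system", "You are a helpful assistant."));
    //   messages.add(new ChatMessage("user", "Hello!"));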

    /**
     * These are the arguments that control the result of the output. For more
     * information, refer to the <a href="https://platform.openai.com/docs/api-reference/completions/create">OpenAI Docs</a>.
     */
    public static final class ChatCompletionRequest {

        private final String model;
        private final List<ChatMessage> messages;
        private final float temperature;
        @SerializedName("top_p")
        private final float topP;
        private final int n;
        private final boolean stream;
        private final String stop;
        @SerializedName("max_tokens")
        private final Integer maxTokens;
        @SerializedName("presence_penalty")
        private final float presencePenalty;
        @SerializedName("frequency_penalty")
        private final float frequencyPenalty;
        @SerializedName("logit_bias")
        private final JsonObject logitBias;
        private final String user;

        /**
         * Shorthand constructor for {@link #ChatCompletionRequest(String, List, float, float, int, boolean, String, Integer, float, float, JsonObject, String)}.
         *
         * @param model The model to use to generate the text. Recommended: "gpt-3.5-turbo".
         * @param messages All previous messages from the conversation.
         */
        public ChatCompletionRequest(String model, List<ChatMessage> messages) {
            this(model, messages, 1.0f, 1.0f, 1, false, null, null, 0f, 0f, null, null);
        }

        /**
         * @param model The model used to generate the text. Recommended: "gpt-3.5-turbo".
         * @param messages All previous messages from the conversation.
         * @param temperature How "creative" the results are, in the range [0.0, 2.0].
         * @param topP Controls how "on topic" the tokens are.
         * @param n How many responses to generate. Numbers greater than 1 will chew through your tokens.
         * @param stream <b>UNTESTED</b>; recommended to keep this false.
         * @param stop The sequence used to stop generating tokens.
         * @param maxTokens The maximum number of tokens to use.
         * @param presencePenalty Discourages talking about duplicate topics.
         * @param frequencyPenalty Discourages repeating the same text.
         * @param logitBias Controls the likelihood of specific tokens being used.
         * @param user Who sent this request (for moderation).
         */
        public ChatCompletionRequest(String model, List<ChatMessage> messages, float temperature, float topP, int n, boolean stream, String stop, Integer maxTokens, float presencePenalty, float frequencyPenalty, JsonObject logitBias, String user) {
            this.model = model;
            this.messages = new ArrayList<>(messages); // Use a mutable copy of the list
            this.temperature = temperature;
            this.topP = topP;
            this.n = n;
            this.stream = stream;
            this.stop = stop;
            this.maxTokens = maxTokens;
            this.presencePenalty = presencePenalty;
            this.frequencyPenalty = frequencyPenalty;
            this.logitBias = logitBias;
            this.user = user;
        }

        public String getModel() {
            return this.model;
        }

        public List<ChatMessage> getMessages() {
            return this.messages;
        }

        public float getTemperature() {
            return this.temperature;
        }

        public float getTopP() {
            return this.topP;
        }

        public int getN() {
            return this.n;
        }

        public boolean getStream() {
            return this.stream;
        }

        public String getStop() {
            return this.stop;
        }

        public Integer getMaxTokens() {
            return this.maxTokens;
        }

        public float getPresencePenalty() {
            return this.presencePenalty;
        }

        public float getFrequencyPenalty() {
            return this.frequencyPenalty;
        }

        public JsonObject getLogitBias() {
            return this.logitBias;
        }

        public String getUser() {
            return this.user;
        }
    }
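
    // For reference (illustrative, not exhaustive): because Gson skips null
    // fields by default, a request built with the shorthand constructor
    // serializes to roughly the following JSON body:
    //
    //   {"model": "gpt-3.5-turbo",
    //    "messages": [{"role": "user", "content": "Hello!"}],
    //    "temperature": 1.0, "top_p": 1.0, "n": 1, "stream": false,
    //    "presence_penalty": 0.0, "frequency_penalty": 0.0}
    //
    // The @SerializedName annotations above are what map the camelCase fields
    // to the snake_case names the API expects.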

    /**
     * This is the object returned from the API. You want to access choices[0]
     * to get your response.
     */
    public static final class ChatCompletionResponse {

        private final String id;
        private final String object;
        private final long created;
        private final List<ChatCompletionChoice> choices;
        private final ChatCompletionUsage usage;

        public ChatCompletionResponse(String id, String object, long created, List<ChatCompletionChoice> choices, ChatCompletionUsage usage) {
            this.id = id;
            this.object = object;
            this.created = created;
            this.choices = choices;
            this.usage = usage;
        }

        public ChatCompletionResponse(JsonObject json) {
            this(json.get("id").getAsString(),
                    json.get("object").getAsString(),
                    json.get("created").getAsLong(),
                    StreamSupport.stream(json.get("choices").getAsJsonArray().spliterator(), false)
                            .map(element -> new ChatCompletionChoice(element.getAsJsonObject()))
                            .collect(Collectors.toList()),
                    new ChatCompletionUsage(json.get("usage").getAsJsonObject()));
        }

        public String getId() {
            return this.id;
        }

        public String getObject() {
            return this.object;
        }

        public long getCreated() {
            return this.created;
        }

        public List<ChatCompletionChoice> getChoices() {
            return this.choices;
        }

        public ChatCompletionUsage getUsage() {
            return this.usage;
        }
    }

    public static final class ChatCompletionChoice {

        private final int index;
        private final ChatMessage message;
        private final String finishReason;

        /**
         * Holds the data for 1 generated text completion.
         *
         * @param index The index in the choices array (0 when n=1).
         * @param message The generated message.
         * @param finishReason Why the bot stopped generating tokens.
         */
        public ChatCompletionChoice(int index, ChatMessage message, String finishReason) {
            this.index = index;
            this.message = message;
            this.finishReason = finishReason;
        }

        public ChatCompletionChoice(JsonObject json) {
            // getAsString() returns the raw value; toString() would keep the surrounding quotes
            this(json.get("index").getAsInt(), new ChatMessage(json.get("message").getAsJsonObject()), json.get("finish_reason").getAsString());
        }

        public int getIndex() {
            return this.index;
        }

        public ChatMessage getMessage() {
            return this.message;
        }

        public String getFinishReason() {
            return this.finishReason;
        }
    }

    public static final class ChatCompletionUsage {

        private final int promptTokens;
        private final int completionTokens;
        private final int totalTokens;

        /**
         * Holds how many tokens were used by your API request. Use these
         * token counts to calculate how much money each request has cost you.
         *
         * @param promptTokens How many tokens the input used.
         * @param completionTokens How many tokens the output used.
         * @param totalTokens How many tokens were used in total.
         */
        public ChatCompletionUsage(int promptTokens, int completionTokens, int totalTokens) {
            this.promptTokens = promptTokens;
            this.completionTokens = completionTokens;
            this.totalTokens = totalTokens;
        }

        public ChatCompletionUsage(JsonObject json) {
            this(json.get("prompt_tokens").getAsInt(), json.get("completion_tokens").getAsInt(), json.get("total_tokens").getAsInt());
        }

        public int getPromptTokens() {
            return this.promptTokens;
        }

        public int getCompletionTokens() {
            return this.completionTokens;
        }

        public int getTotalTokens() {
            return this.totalTokens;
        }
    }
}
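
For readers who want to see how the API added in this commit fits together, below is a minimal usage sketch. It is illustrative only: the ChatBotExample class, the placeholder API key, the example messages, and the choice of "gpt-3.5-turbo" (the model the Javadoc recommends) are assumptions, not part of the committed file.

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class ChatBotExample {

    public static void main(String[] args) throws IOException {
        // Placeholder key; supply your own "sk-..." key.
        ChatBot bot = new ChatBot("sk-your-api-key");

        // Build the conversation, oldest message first.
        List<ChatBot.ChatMessage> messages = new ArrayList<>();
        messages.add(new ChatBot.ChatMessage("system", "You are a helpful assistant."));
        messages.add(new ChatBot.ChatMessage("user", "Say hello in one sentence."));

        // The shorthand constructor fills in the default sampling settings.
        ChatBot.ChatCompletionRequest request =
                new ChatBot.ChatCompletionRequest("gpt-3.5-turbo", messages);

        // Blocks until OpenAI responds; choices.get(0) holds the generated reply.
        ChatBot.ChatCompletionResponse response = bot.generateResponse(request);
        System.out.println(response.getChoices().get(0).getMessage().getContent());
        System.out.println("Tokens used: " + response.getUsage().getTotalTokens());
    }
}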

0 commit comments