Skip to content

Commit 624b6c7

Browse files
kabirkhanalejandrojnmCopilotkrissetto
authored
Add file input support to Chat Completion types (#4)
* Add file part support to chat message structure Introduces ChatMessagePartFile struct and ChatMessagePartTypeFile constant to support file attachments in chat messages. Updates ChatMessagePart to include file parts and adds comprehensive tests for serialization, deserialization, and constant definitions. * Rename ChatMessagePartFile to ChatMessageFile Refactored struct name for file parts in chat messages from ChatMessagePartFile to ChatMessageFile for consistency and clarity. * Fix indentation in ChatMessagePart struct Corrected the indentation of the File field in the ChatMessagePart struct for improved code readability and consistency. * Update tests to use ChatMessageFile type Replaces usage of ChatMessagePartFile with ChatMessageFile in chat_test.go to reflect updated type naming in the openai package. Also renames related test function for consistency. * Fix formatting in multipart chat message test Split a long conditional statement in TestMultipartChatMessageSerialization for improved readability. * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Stop stripping dots in azure model mapper for models that aren't 3.5 based (sashabaranov#1079) fixes sashabaranov#978 Signed-off-by: Christopher Petito <chrisjpetito@gmail.com> --------- Signed-off-by: Christopher Petito <chrisjpetito@gmail.com> Co-authored-by: Alejandro J. Nuñez Madrazo <alejandrojnm@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Christopher Petito <47751006+krissetto@users.noreply.github.com>
1 parent 4158511 commit 624b6c7

File tree

4 files changed

+131
-3
lines changed

4 files changed

+131
-3
lines changed

chat.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
8181
Detail ImageURLDetail `json:"detail,omitempty"`
8282
}
8383

84+
// ChatMessageFile is a placeholder for file parts in chat messages.
85+
type ChatMessageFile struct {
86+
FileID string `json:"file_id,omitempty"`
87+
FileName string `json:"filename,omitempty"`
88+
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
89+
}
90+
8491
type ChatMessagePartType string
8592

8693
const (
8794
ChatMessagePartTypeText ChatMessagePartType = "text"
8895
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
96+
ChatMessagePartTypeFile ChatMessagePartType = "file"
8997
)
9098

9199
type ChatMessagePart struct {
92100
Type ChatMessagePartType `json:"type,omitempty"`
93101
Text string `json:"text,omitempty"`
94102
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
103+
File *ChatMessageFile `json:"file,omitempty"`
95104
}
96105

97106
type ChatCompletionMessage struct {

chat_test.go

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,14 @@ func TestMultipartChatCompletions(t *testing.T) {
797797
Detail: openai.ImageURLDetailLow,
798798
},
799799
},
800+
{
801+
Type: openai.ChatMessagePartTypeFile,
802+
File: &openai.ChatMessageFile{
803+
FileID: "file-123",
804+
FileName: "test.txt",
805+
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
806+
},
807+
},
800808
},
801809
},
802810
},
@@ -807,7 +815,8 @@ func TestMultipartChatCompletions(t *testing.T) {
807815
func TestMultipartChatMessageSerialization(t *testing.T) {
808816
jsonText := `[{"role":"system","content":"system-message"},` +
809817
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
810-
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
818+
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
819+
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`
811820

812821
var msgs []openai.ChatCompletionMessage
813822
err := json.Unmarshal([]byte(jsonText), &msgs)
@@ -820,7 +829,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
820829
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
821830
t.Errorf("invalid user message: %v", msgs[0])
822831
}
823-
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
832+
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
824833
t.Errorf("invalid user message")
825834
}
826835
parts := msgs[1].MultiContent
@@ -830,6 +839,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
830839
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
831840
t.Errorf("invalid image_url part")
832841
}
842+
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" ||
843+
parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
844+
t.Errorf("invalid file part: %v", parts[2])
845+
}
833846

834847
s, err := json.Marshal(msgs)
835848
if err != nil {
@@ -876,6 +889,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
876889
}
877890
}
878891

892+
func TestChatMessageFile(t *testing.T) {
893+
// Test file part with FileID
894+
filePart := openai.ChatMessagePart{
895+
Type: openai.ChatMessagePartTypeFile,
896+
File: &openai.ChatMessageFile{
897+
FileID: "file-abc123",
898+
},
899+
}
900+
901+
// Test serialization
902+
data, err := json.Marshal(filePart)
903+
if err != nil {
904+
t.Fatalf("Expected no error: %s", err)
905+
}
906+
907+
expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
908+
result := strings.ReplaceAll(string(data), " ", "")
909+
if result != expected {
910+
t.Errorf("Expected %s, got %s", expected, result)
911+
}
912+
913+
// Test deserialization
914+
var parsedPart openai.ChatMessagePart
915+
err = json.Unmarshal(data, &parsedPart)
916+
if err != nil {
917+
t.Fatalf("Expected no error: %s", err)
918+
}
919+
920+
if parsedPart.Type != openai.ChatMessagePartTypeFile {
921+
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
922+
}
923+
if parsedPart.File == nil {
924+
t.Fatal("Expected File to be non-nil")
925+
}
926+
if parsedPart.File.FileID != "file-abc123" {
927+
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
928+
}
929+
930+
// Test file part with all fields
931+
filePartComplete := openai.ChatMessagePart{
932+
Type: openai.ChatMessagePartTypeFile,
933+
File: &openai.ChatMessageFile{
934+
FileID: "file-xyz789",
935+
FileName: "document.pdf",
936+
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
937+
},
938+
}
939+
940+
data, err = json.Marshal(filePartComplete)
941+
if err != nil {
942+
t.Fatalf("Expected no error: %s", err)
943+
}
944+
945+
expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
946+
result = strings.ReplaceAll(string(data), " ", "")
947+
if result != expected {
948+
t.Errorf("Expected %s, got %s", expected, result)
949+
}
950+
951+
// Test deserialization of complete file part
952+
var parsedCompleteFile openai.ChatMessagePart
953+
err = json.Unmarshal(data, &parsedCompleteFile)
954+
if err != nil {
955+
t.Fatalf("Expected no error: %s", err)
956+
}
957+
958+
if parsedCompleteFile.File.FileID != "file-xyz789" {
959+
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
960+
}
961+
if parsedCompleteFile.File.FileName != "document.pdf" {
962+
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
963+
}
964+
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
965+
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
966+
}
967+
}
968+
969+
func TestChatMessagePartTypeConstants(t *testing.T) {
970+
// Test that the new file constant is properly defined
971+
if openai.ChatMessagePartTypeFile != "file" {
972+
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
973+
}
974+
975+
// Test all part type constants
976+
expectedTypes := map[openai.ChatMessagePartType]string{
977+
openai.ChatMessagePartTypeText: "text",
978+
openai.ChatMessagePartTypeImageURL: "image_url",
979+
openai.ChatMessagePartTypeFile: "file",
980+
}
981+
982+
for constant, expected := range expectedTypes {
983+
if string(constant) != expected {
984+
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
985+
}
986+
}
987+
}
988+
879989
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
880990
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
881991
var err error

config.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai
33
import (
44
"net/http"
55
"regexp"
6+
"strings"
67
)
78

89
const (
@@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
7071
APIType: APITypeAzure,
7172
APIVersion: "2023-05-15",
7273
AzureModelMapperFunc: func(model string) string {
73-
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
74+
// only 3.5 models have the "." stripped in their names
75+
if strings.Contains(model, "3.5") {
76+
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
77+
}
78+
return strings.ReplaceAll(model, ":", "")
7479
},
7580

7681
HTTPClient: &http.Client{},

config_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
2020
Model: "gpt-3.5-turbo-0301",
2121
Expect: "gpt-35-turbo-0301",
2222
},
23+
{
24+
Model: "gpt-4.1",
25+
Expect: "gpt-4.1",
26+
},
2327
{
2428
Model: "text-embedding-ada-002",
2529
Expect: "text-embedding-ada-002",

0 commit comments

Comments
 (0)