 import time
 import traceback
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import shortuuid
 from pydantic import BaseModel
@@ -486,7 +486,7 @@ async def get_setting(setting: str):
 
 class ChatMessage(BaseModel):
     role: str
-    content: str
+    content: Union[str, List[Dict[str, Any]]]
 
 class ChatCompletionRequest(BaseModel):
     model: str = "default-model"
@@ -495,10 +495,8 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = None
     stream: Optional[bool] = False
 
-async def openai_compatible_generator(message: str):
-    for i, chunk in enumerate(
-        async_interpreter.chat(message, stream=True, display=True)
-    ):
+async def openai_compatible_generator():
+    for i, chunk in enumerate(async_interpreter._respond_and_store()):
         output_content = None
 
         if chunk["type"] == "message" and "content" in chunk:
@@ -519,17 +517,56 @@ async def openai_compatible_generator(message: str):
 
 @router.post("/openai/chat/completions")
 async def chat_completion(request: ChatCompletionRequest):
-    assert request.messages[-1].role == "user"
-    message = request.messages[-1].content
+    # Convert to LMC
+
+    user_messages = []
+    for message in reversed(request.messages):
+        if message.role == "user":
+            user_messages.append(message)
+        else:
+            break
+    user_messages.reverse()
+
+    for message in user_messages:
+        if type(message.content) == str:
+            async_interpreter.messages.append(
+                {"role": "user", "type": "message", "content": message.content}
+            )
+        if type(message.content) == list:
+            for content in message.content:
+                if content["type"] == "text":
+                    async_interpreter.messages.append(
+                        {"role": "user", "type": "message", "content": content["text"]}
+                    )
+                elif content["type"] == "image_url":
+                    if "url" not in content["image_url"]:
+                        raise Exception("`url` must be in `image_url`.")
+                    url = content["image_url"]["url"]
+                    print(url[:100])
+                    if "base64," not in url:
+                        raise Exception(
+                            '''Image must be in the format: "data:image/jpeg;base64,{base64_image}"'''
+                        )
+
+                    # data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAA6oA...
+
+                    data = url.split("base64,")[1]
+                    format = "base64." + url.split(";")[0].split("/")[1]
+                    async_interpreter.messages.append(
+                        {
+                            "role": "user",
+                            "type": "image",
+                            "format": format,
+                            "content": data,
+                        }
+                    )
 
     if request.stream:
         return StreamingResponse(
-            openai_compatible_generator(message), media_type="application/x-ndjson"
+            openai_compatible_generator(), media_type="application/x-ndjson"
         )
     else:
-        messages = async_interpreter.chat(
-            message=message, stream=False, display=False
-        )
+        messages = async_interpreter.chat(message="", stream=False, display=True)
         content = messages[-1]["content"]
         return {
            "id": "200",
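For reference, a minimal client-side sketch of how the new list-style `content` (text plus a base64 `image_url`) could be exercised against this endpoint. The host, port, and file name are assumptions for illustration and are not part of this commit; only the request shape follows the code above.

```python
import base64

import requests

# Assumed server location; point this at wherever the FastAPI router is mounted.
BASE_URL = "http://localhost:8000"

# The endpoint only accepts images as base64 data URIs ("data:image/...;base64,...").
with open("screenshot.png", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "default-model",
    "stream": False,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{b64_image}"},
                },
            ],
        }
    ],
}

response = requests.post(f"{BASE_URL}/openai/chat/completions", json=payload)
print(response.json())
```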