 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+
+llama: Optional[llama_cpp.Llama] = None
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
 
 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
 
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
@@ -102,7 +111,7 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -148,7 +157,7 @@ class Config:
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -202,7 +211,7 @@ class Config:
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -256,7 +265,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",
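
A minimal usage sketch of the refactored factory (not part of the diff above): build the app with create_app() and serve it with uvicorn. The module name fastapi_server is a placeholder for wherever this file lives, uvicorn is assumed to be installed, and the host/port values are arbitrary. Note that model is now a required field; pydantic's BaseSettings will also populate it from the MODEL environment variable when Settings() is constructed with no arguments.

import os

import uvicorn  # assumed to be installed; not added by this diff

from fastapi_server import Settings, create_app  # hypothetical module name for the file above

# `model` is now required, so pass it explicitly (or export MODEL and let
# BaseSettings read it from the environment via Settings()).
settings = Settings(model=os.environ["MODEL"])
app = create_app(settings)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)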