@@ -8,16 +8,21 @@ import (
88 "log"
99 "net/http"
1010 "net/http/httputil"
11- "path"
1211 "strings"
1312
1413 "github.com/bytedance/sonic"
1514 "github.com/gin-gonic/gin"
1615 "github.com/pkg/errors"
1716)
1817
18+ func ProxyWithConverter (requestConverter RequestConverter ) gin.HandlerFunc {
19+ return func (c * gin.Context ) {
20+ Proxy (c , requestConverter )
21+ }
22+ }
23+
1924// Proxy Azure OpenAI
20- func Proxy (c * gin.Context ) {
25+ func Proxy (c * gin.Context , requestConverter RequestConverter ) {
2126 if c .Request .Method == http .MethodOptions {
2227 c .Header ("Access-Control-Allow-Origin" , "*" )
2328 c .Header ("Access-Control-Allow-Methods" , "GET, OPTIONS, POST" )
@@ -34,38 +39,48 @@ func Proxy(c *gin.Context) {
3439 body , _ := io .ReadAll (req .Body )
3540 req .Body = io .NopCloser (bytes .NewBuffer (body ))
3641
37- // get model from body
38- model , err := sonic .Get (body , "model" )
39- if err != nil {
40- util .SendError (c , errors .Wrap (err , "get model error" ))
41- return
42+ // get model from url params or body
43+ model := c .Param ("model" )
44+ if model == "" {
45+ _model , err := sonic .Get (body , "model" )
46+ if err != nil {
47+ util .SendError (c , errors .Wrap (err , "get model error" ))
48+ return
49+ }
50+ _modelStr , err := _model .String ()
51+ if err != nil {
52+ util .SendError (c , errors .Wrap (err , "get model name error" ))
53+ return
54+ }
55+ model = _modelStr
4256 }
4357
4458 // get deployment from request
45- deployment , err := model . String ( )
59+ deployment , err := GetDeploymentByModel ( model )
4660 if err != nil {
47- util .SendError (c , errors . Wrap ( err , "get deployment error" ) )
61+ util .SendError (c , err )
4862 return
4963 }
50- deployment = GetDeploymentByModel (deployment )
5164
52- // get auth token from header
53- rawToken := req .Header .Get ("Authorization" )
54- token := strings .TrimPrefix (rawToken , "Bearer " )
65+ // get auth token from header or deployemnt config
66+ token := deployment .ApiKey
67+ if token == "" {
68+ rawToken := req .Header .Get ("Authorization" )
69+ token = strings .TrimPrefix (rawToken , "Bearer " )
70+ }
71+ if token == "" {
72+ util .SendError (c , errors .New ("token is empty" ))
73+ return
74+ }
5575 req .Header .Set (AuthHeaderKey , token )
5676 req .Header .Del ("Authorization" )
5777
5878 originURL := req .URL .String ()
59- req .Host = AzureOpenAIEndpointParse .Host
60- req .URL .Scheme = AzureOpenAIEndpointParse .Scheme
61- req .URL .Host = AzureOpenAIEndpointParse .Host
62- req .URL .Path = path .Join (fmt .Sprintf ("/openai/deployments/%s" , deployment ), strings .Replace (req .URL .Path , "/v1/" , "/" , 1 ))
63- req .URL .RawPath = req .URL .EscapedPath ()
64-
65- query := req .URL .Query ()
66- query .Add ("api-version" , AzureOpenAIAPIVer )
67- req .URL .RawQuery = query .Encode ()
68-
79+ req , err = requestConverter .Convert (req , deployment )
80+ if err != nil {
81+ util .SendError (c , errors .Wrap (err , "convert request error" ))
82+ return
83+ }
6984 log .Printf ("proxying request [%s] %s -> %s" , model , originURL , req .URL .String ())
7085 }
7186
@@ -80,10 +95,10 @@ func Proxy(c *gin.Context) {
8095 }
8196}
8297
83- func GetDeploymentByModel (model string ) string {
84- if v , ok := AzureOpenAIModelMapper [model ]; ok {
85- return v
98+ func GetDeploymentByModel (model string ) (* DeploymentConfig , error ) {
99+ deploymentConfig , exist := ModelDeploymentConfig [model ]
100+ if ! exist {
101+ return nil , errors .New (fmt .Sprintf ("deployment config for %s not found" , model ))
86102 }
87-
88- return fallbackModelMapper .ReplaceAllString (model , "" )
103+ return & deploymentConfig , nil
89104}
0 commit comments