Skip to content

Commit 9ba3a1e

Browse files
committed
🧱 update router
1 parent 932eefe commit 9ba3a1e

File tree

2 files changed

+54
-31
lines changed

2 files changed

+54
-31
lines changed

azure/proxy.go

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@ import (
88
"log"
99
"net/http"
1010
"net/http/httputil"
11-
"path"
1211
"strings"
1312

1413
"github.com/bytedance/sonic"
1514
"github.com/gin-gonic/gin"
1615
"github.com/pkg/errors"
1716
)
1817

18+
func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc {
19+
return func(c *gin.Context) {
20+
Proxy(c, requestConverter)
21+
}
22+
}
23+
1924
// Proxy Azure OpenAI
20-
func Proxy(c *gin.Context) {
25+
func Proxy(c *gin.Context, requestConverter RequestConverter) {
2126
if c.Request.Method == http.MethodOptions {
2227
c.Header("Access-Control-Allow-Origin", "*")
2328
c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST")
@@ -34,38 +39,48 @@ func Proxy(c *gin.Context) {
3439
body, _ := io.ReadAll(req.Body)
3540
req.Body = io.NopCloser(bytes.NewBuffer(body))
3641

37-
// get model from body
38-
model, err := sonic.Get(body, "model")
39-
if err != nil {
40-
util.SendError(c, errors.Wrap(err, "get model error"))
41-
return
42+
// get model from url params or body
43+
model := c.Param("model")
44+
if model == "" {
45+
_model, err := sonic.Get(body, "model")
46+
if err != nil {
47+
util.SendError(c, errors.Wrap(err, "get model error"))
48+
return
49+
}
50+
_modelStr, err := _model.String()
51+
if err != nil {
52+
util.SendError(c, errors.Wrap(err, "get model name error"))
53+
return
54+
}
55+
model = _modelStr
4256
}
4357

4458
// get deployment from request
45-
deployment, err := model.String()
59+
deployment, err := GetDeploymentByModel(model)
4660
if err != nil {
47-
util.SendError(c, errors.Wrap(err, "get deployment error"))
61+
util.SendError(c, err)
4862
return
4963
}
50-
deployment = GetDeploymentByModel(deployment)
5164

52-
// get auth token from header
53-
rawToken := req.Header.Get("Authorization")
54-
token := strings.TrimPrefix(rawToken, "Bearer ")
65+
// get auth token from header or deployemnt config
66+
token := deployment.ApiKey
67+
if token == "" {
68+
rawToken := req.Header.Get("Authorization")
69+
token = strings.TrimPrefix(rawToken, "Bearer ")
70+
}
71+
if token == "" {
72+
util.SendError(c, errors.New("token is empty"))
73+
return
74+
}
5575
req.Header.Set(AuthHeaderKey, token)
5676
req.Header.Del("Authorization")
5777

5878
originURL := req.URL.String()
59-
req.Host = AzureOpenAIEndpointParse.Host
60-
req.URL.Scheme = AzureOpenAIEndpointParse.Scheme
61-
req.URL.Host = AzureOpenAIEndpointParse.Host
62-
req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", deployment), strings.Replace(req.URL.Path, "/v1/", "/", 1))
63-
req.URL.RawPath = req.URL.EscapedPath()
64-
65-
query := req.URL.Query()
66-
query.Add("api-version", AzureOpenAIAPIVer)
67-
req.URL.RawQuery = query.Encode()
68-
79+
req, err = requestConverter.Convert(req, deployment)
80+
if err != nil {
81+
util.SendError(c, errors.Wrap(err, "convert request error"))
82+
return
83+
}
6984
log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String())
7085
}
7186

@@ -80,10 +95,10 @@ func Proxy(c *gin.Context) {
8095
}
8196
}
8297

83-
func GetDeploymentByModel(model string) string {
84-
if v, ok := AzureOpenAIModelMapper[model]; ok {
85-
return v
98+
func GetDeploymentByModel(model string) (*DeploymentConfig, error) {
99+
deploymentConfig, exist := ModelDeploymentConfig[model]
100+
if !exist {
101+
return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model))
86102
}
87-
88-
return fallbackModelMapper.ReplaceAllString(model, "")
103+
return &deploymentConfig, nil
89104
}

cmd/router.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"github.com/gin-gonic/gin"
5+
"github.com/spf13/viper"
56
"github.com/stulzq/azure-openai-proxy/azure"
67
)
78

@@ -14,7 +15,14 @@ func registerRoute(r *gin.Engine) {
1415
r.Any("/health", func(c *gin.Context) {
1516
c.Status(200)
1617
})
17-
18-
r.Any("/v1/*path", azure.Proxy)
19-
18+
apiBase := viper.GetString("api_base")
19+
stripPrefixConverter := azure.NewStripPrefixConverter(apiBase)
20+
templateConverter := azure.NewTemplateConverter("/openai/deployments/{{.DeploymentName}}/embeddings")
21+
apiBasedRouter := r.Group(apiBase)
22+
{
23+
apiBasedRouter.Any("/engines/:model/embeddings", azure.ProxyWithConverter(templateConverter))
24+
apiBasedRouter.Any("/completions", azure.ProxyWithConverter(stripPrefixConverter))
25+
apiBasedRouter.Any("/chat/completions", azure.ProxyWithConverter(stripPrefixConverter))
26+
apiBasedRouter.Any("/embeddings", azure.ProxyWithConverter(stripPrefixConverter))
27+
}
2028
}

0 commit comments

Comments
 (0)