Skip to content

Commit 932eefe

Browse files
committed
🧱 unify azure config from env or yaml file
1 parent 8f6b8ef commit 932eefe

File tree

1 file changed

+77
-27
lines changed

1 file changed

+77
-27
lines changed

azure/init.go

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package azure
22

33
import (
4+
"fmt"
5+
"github.com/spf13/viper"
46
"github.com/stulzq/azure-openai-proxy/constant"
7+
"github.com/stulzq/azure-openai-proxy/util"
58
"log"
69
"net/url"
7-
"os"
8-
"regexp"
910
"strings"
1011
)
1112

@@ -14,43 +15,92 @@ const (
1415
)
1516

1617
var (
17-
AzureOpenAIEndpoint = ""
18-
AzureOpenAIEndpointParse *url.URL
19-
20-
AzureOpenAIAPIVer = ""
21-
22-
AzureOpenAIModelMapper = map[string]string{
23-
"gpt-3.5-turbo": "gpt-35-turbo",
24-
}
25-
fallbackModelMapper = regexp.MustCompile(`[.:]`)
18+
C Config
19+
ModelDeploymentConfig = map[string]DeploymentConfig{}
2620
)
2721

28-
func Init() {
29-
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
30-
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
22+
func Init() error {
23+
var (
24+
apiVersion string
25+
endpoint string
26+
openaiModelMapper string
27+
err error
28+
)
3129

32-
if AzureOpenAIAPIVer == "" {
33-
AzureOpenAIAPIVer = "2023-03-15-preview"
30+
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
31+
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
32+
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
33+
if endpoint != "" && openaiModelMapper != "" {
34+
if apiVersion == "" {
35+
apiVersion = "2023-03-15-preview"
36+
}
37+
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
38+
} else {
39+
if err = InitFromConfigFile(); err != nil {
40+
return err
41+
}
3442
}
3543

36-
var err error
37-
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
38-
if err != nil {
39-
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
44+
// ensure apiBase likes /v1
45+
apiBase := viper.GetString("api_base")
46+
if !strings.HasPrefix(apiBase, "/") {
47+
apiBase = "/" + apiBase
48+
}
49+
if strings.HasSuffix(apiBase, "/") {
50+
apiBase = apiBase[:len(apiBase)-1]
51+
}
52+
viper.Set("api_base", apiBase)
53+
log.Printf("apiBase is: %s", apiBase)
54+
for _, itemConfig := range C.DeploymentConfig {
55+
u, err := url.Parse(itemConfig.Endpoint)
56+
if err != nil {
57+
return fmt.Errorf("parse endpoint error: %w", err)
58+
}
59+
itemConfig.EndpointUrl = u
60+
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
4061
}
62+
return err
63+
}
4164

42-
if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
43-
for _, pair := range strings.Split(v, ",") {
65+
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
66+
log.Println("Init from environment variables")
67+
if openaiModelMapper != "" {
68+
// openaiModelMapper example:
69+
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
70+
for _, pair := range strings.Split(openaiModelMapper, ",") {
4471
info := strings.Split(pair, "=")
4572
if len(info) != 2 {
4673
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
4774
}
48-
49-
AzureOpenAIModelMapper[info[0]] = info[1]
75+
modelName, deploymentName := info[0], info[1]
76+
ModelDeploymentConfig[modelName] = DeploymentConfig{
77+
DeploymentName: deploymentName,
78+
ModelName: modelName,
79+
Endpoint: endpoint,
80+
ApiKey: "",
81+
ApiVersion: apiVersion,
82+
}
5083
}
5184
}
85+
}
86+
87+
func InitFromConfigFile() error {
88+
log.Println("Init from config file")
89+
workDir := util.GetWorkdir()
90+
viper.SetConfigName("config")
91+
viper.SetConfigType("yaml")
92+
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
93+
if err := viper.ReadInConfig(); err != nil {
94+
log.Printf("read config file error: %+v\n", err)
95+
return err
96+
}
5297

53-
log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
54-
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
55-
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
98+
if err := viper.Unmarshal(&C); err != nil {
99+
log.Printf("unmarshal config file error: %+v\n", err)
100+
return err
101+
}
102+
for _, configItem := range C.DeploymentConfig {
103+
ModelDeploymentConfig[configItem.ModelName] = configItem
104+
}
105+
return nil
56106
}

0 commit comments

Comments
 (0)