11package azure
22
33import (
4+ "fmt"
5+ "github.com/spf13/viper"
46 "github.com/stulzq/azure-openai-proxy/constant"
7+ "github.com/stulzq/azure-openai-proxy/util"
58 "log"
69 "net/url"
7- "os"
8- "regexp"
910 "strings"
1011)
1112
@@ -14,43 +15,92 @@ const (
1415)
1516
1617var (
17- AzureOpenAIEndpoint = ""
18- AzureOpenAIEndpointParse * url.URL
19-
20- AzureOpenAIAPIVer = ""
21-
22- AzureOpenAIModelMapper = map [string ]string {
23- "gpt-3.5-turbo" : "gpt-35-turbo" ,
24- }
25- fallbackModelMapper = regexp .MustCompile (`[.:]` )
18+ C Config
19+ ModelDeploymentConfig = map [string ]DeploymentConfig {}
2620)
2721
28- func Init () {
29- AzureOpenAIAPIVer = os .Getenv (constant .ENV_AZURE_OPENAI_API_VER )
30- AzureOpenAIEndpoint = os .Getenv (constant .ENV_AZURE_OPENAI_ENDPOINT )
22+ func Init () error {
23+ var (
24+ apiVersion string
25+ endpoint string
26+ openaiModelMapper string
27+ err error
28+ )
3129
32- if AzureOpenAIAPIVer == "" {
33- AzureOpenAIAPIVer = "2023-03-15-preview"
30+ apiVersion = viper .GetString (constant .ENV_AZURE_OPENAI_API_VER )
31+ endpoint = viper .GetString (constant .ENV_AZURE_OPENAI_ENDPOINT )
32+ openaiModelMapper = viper .GetString (constant .ENV_AZURE_OPENAI_MODEL_MAPPER )
33+ if endpoint != "" && openaiModelMapper != "" {
34+ if apiVersion == "" {
35+ apiVersion = "2023-03-15-preview"
36+ }
37+ InitFromEnvironmentVariables (apiVersion , endpoint , openaiModelMapper )
38+ } else {
39+ if err = InitFromConfigFile (); err != nil {
40+ return err
41+ }
3442 }
3543
36- var err error
37- AzureOpenAIEndpointParse , err = url .Parse (AzureOpenAIEndpoint )
38- if err != nil {
39- log .Fatal ("parse AzureOpenAIEndpoint error: " , err )
44+ // ensure apiBase likes /v1
45+ apiBase := viper .GetString ("api_base" )
46+ if ! strings .HasPrefix (apiBase , "/" ) {
47+ apiBase = "/" + apiBase
48+ }
49+ if strings .HasSuffix (apiBase , "/" ) {
50+ apiBase = apiBase [:len (apiBase )- 1 ]
51+ }
52+ viper .Set ("api_base" , apiBase )
53+ log .Printf ("apiBase is: %s" , apiBase )
54+ for _ , itemConfig := range C .DeploymentConfig {
55+ u , err := url .Parse (itemConfig .Endpoint )
56+ if err != nil {
57+ return fmt .Errorf ("parse endpoint error: %w" , err )
58+ }
59+ itemConfig .EndpointUrl = u
60+ ModelDeploymentConfig [itemConfig .ModelName ] = itemConfig
4061 }
62+ return err
63+ }
4164
42- if v := os .Getenv (constant .ENV_AZURE_OPENAI_MODEL_MAPPER ); v != "" {
43- for _ , pair := range strings .Split (v , "," ) {
65+ func InitFromEnvironmentVariables (apiVersion , endpoint , openaiModelMapper string ) {
66+ log .Println ("Init from environment variables" )
67+ if openaiModelMapper != "" {
68+ // openaiModelMapper example:
69+ // gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
70+ for _ , pair := range strings .Split (openaiModelMapper , "," ) {
4471 info := strings .Split (pair , "=" )
4572 if len (info ) != 2 {
4673 log .Fatalf ("error parsing %s, invalid value %s" , constant .ENV_AZURE_OPENAI_MODEL_MAPPER , pair )
4774 }
48-
49- AzureOpenAIModelMapper [info [0 ]] = info [1 ]
75+ modelName , deploymentName := info [0 ], info [1 ]
76+ ModelDeploymentConfig [modelName ] = DeploymentConfig {
77+ DeploymentName : deploymentName ,
78+ ModelName : modelName ,
79+ Endpoint : endpoint ,
80+ ApiKey : "" ,
81+ ApiVersion : apiVersion ,
82+ }
5083 }
5184 }
85+ }
86+
87+ func InitFromConfigFile () error {
88+ log .Println ("Init from config file" )
89+ workDir := util .GetWorkdir ()
90+ viper .SetConfigName ("config" )
91+ viper .SetConfigType ("yaml" )
92+ viper .AddConfigPath (fmt .Sprintf ("%s/config" , workDir ))
93+ if err := viper .ReadInConfig (); err != nil {
94+ log .Printf ("read config file error: %+v\n " , err )
95+ return err
96+ }
5297
53- log .Println ("AzureOpenAIAPIVer: " , AzureOpenAIAPIVer )
54- log .Println ("AzureOpenAIEndpoint: " , AzureOpenAIEndpoint )
55- log .Println ("AzureOpenAIModelMapper: " , AzureOpenAIModelMapper )
98+ if err := viper .Unmarshal (& C ); err != nil {
99+ log .Printf ("unmarshal config file error: %+v\n " , err )
100+ return err
101+ }
102+ for _ , configItem := range C .DeploymentConfig {
103+ ModelDeploymentConfig [configItem .ModelName ] = configItem
104+ }
105+ return nil
56106}
0 commit comments