1+ import enum
12import os
23from pathlib import Path
34import random
45import re
56from typing import (
6- Any ,
77 Callable ,
88 Iterable ,
99 List ,
1010 Mapping ,
1111 Optional ,
1212 Sequence ,
1313 Tuple ,
14+ TypeVar ,
1415 Union ,
1516 cast ,
1617)
2829 'MAX_INFLIGHT_CHUNKS' ,
2930]
3031
32+
33+ class Undefined (enum .Enum ):
34+ token = object ()
35+
36+
3137_config = None
32- _undefined = object ()
38+ _undefined = Undefined . token
3339
3440API_VERSION = (6 , '20200815' )
3541
@@ -47,8 +53,19 @@ def parse_api_version(value: str) -> Tuple[int, str]:
4753 raise ValueError ('Could not parse the given API version string' , value )
4854
4955
50- def get_env (key : str , default : Any = _undefined , * ,
51- clean : Callable [[str ], Any ] = lambda v : v ):
56+ T = TypeVar ('T' )
57+
58+
59+ def default_clean (v : str ) -> T :
60+ return cast (T , v )
61+
62+
63+ def get_env (
64+ key : str ,
65+ default : Union [str , Undefined ] = _undefined ,
66+ * ,
67+ clean : Callable [[str ], T ] = default_clean ,
68+ ) -> T :
5269 """
5370 Retrieves a configuration value from the environment variables.
5471 The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then
@@ -64,14 +81,14 @@ def get_env(key: str, default: Any = _undefined, *,
6481 :returns: The value processed by the *clean* function.
6582 """
6683 key = key .upper ()
67- v = os .environ .get ('BACKEND_' + key )
68- if v is None :
69- v = os .environ .get ('SORNA_' + key )
70- if v is None :
84+ raw = os .environ .get ('BACKEND_' + key )
85+ if raw is None :
86+ raw = os .environ .get ('SORNA_' + key )
87+ if raw is None :
7188 if default is _undefined :
7289 raise KeyError (key )
73- v = default
74- return clean (v )
90+ raw = default
91+ return clean (raw )
7592
7693
7794def bool_env (v : str ) -> bool :
@@ -86,8 +103,8 @@ def bool_env(v: str) -> bool:
86103def _clean_urls (v : Union [URL , str ]) -> List [URL ]:
87104 if isinstance (v , URL ):
88105 return [v ]
106+ urls = []
89107 if isinstance (v , str ):
90- urls = []
91108 for entry in v .split (',' ):
92109 url = URL (entry )
93110 if not url .is_absolute ():
@@ -96,12 +113,10 @@ def _clean_urls(v: Union[URL, str]) -> List[URL]:
96113 return urls
97114
98115
99- def _clean_tokens (v ):
100- if isinstance (v , str ):
101- if not v :
102- return tuple ()
103- return tuple (v .split (',' ))
104- return tuple (iter (v ))
116+ def _clean_tokens (v : str ) -> Tuple [str , ...]:
117+ if not v :
118+ return tuple ()
119+ return tuple (v .split (',' ))
105120
106121
107122class APIConfig :
@@ -141,21 +156,22 @@ class APIConfig:
141156 <ai.backend.client.kernel.Kernel.get_or_create>` calls.
142157 """
143158
144- DEFAULTS : Mapping [str , Any ] = {
159+ DEFAULTS : Mapping [str , str ] = {
145160 'endpoint' : 'https://api.backend.ai' ,
146161 'endpoint_type' : 'api' ,
147162 'version' : f'v{ API_VERSION [0 ]} .{ API_VERSION [1 ]} ' ,
148163 'hash_type' : 'sha256' ,
149164 'domain' : 'default' ,
150165 'group' : 'default' ,
151- 'connection_timeout' : 10.0 ,
152- 'read_timeout' : None ,
166+ 'connection_timeout' : ' 10.0' ,
167+ 'read_timeout' : '0' ,
153168 }
154169 """
155170 The default values for config parameterse settable via environment variables
156171 xcept the access and secret keys.
157172 """
158173
174+ _endpoints : List [URL ]
159175 _group : str
160176 _hash_type : str
161177
@@ -179,35 +195,39 @@ def __init__(
179195 from . import get_user_agent
180196 self ._endpoints = (
181197 _clean_urls (endpoint ) if endpoint else
182- get_env ('ENDPOINT' , self .DEFAULTS ['endpoint' ], clean = _clean_urls ))
198+ get_env ('ENDPOINT' , self .DEFAULTS ['endpoint' ], clean = _clean_urls )
199+ )
183200 random .shuffle (self ._endpoints )
184- self ._endpoint_type = endpoint_type if endpoint_type is not None \
185- else get_env ('ENDPOINT_TYPE' , self .DEFAULTS ['endpoint_type' ])
186- self ._domain = domain if domain is not None else get_env ('DOMAIN' , self .DEFAULTS ['domain' ])
187- self ._group = group if group is not None else get_env ('GROUP' , self .DEFAULTS ['group' ])
188- self ._version = version if version is not None else self .DEFAULTS ['version' ]
201+ self ._endpoint_type = endpoint_type if endpoint_type is not None else \
202+ get_env ('ENDPOINT_TYPE' , self .DEFAULTS ['endpoint_type' ], clean = str )
203+ self ._domain = domain if domain is not None else \
204+ get_env ('DOMAIN' , self .DEFAULTS ['domain' ], clean = str )
205+ self ._group = group if group is not None else \
206+ get_env ('GROUP' , self .DEFAULTS ['group' ], clean = str )
207+ self ._version = version if version is not None else \
208+ self .DEFAULTS ['version' ]
189209 self ._user_agent = user_agent if user_agent is not None else get_user_agent ()
190210 if self ._endpoint_type == 'api' :
191- self ._access_key = access_key if access_key is not None \
192- else get_env ('ACCESS_KEY' , '' )
193- self ._secret_key = secret_key if secret_key is not None \
194- else get_env ('SECRET_KEY' , '' )
211+ self ._access_key = access_key if access_key is not None else \
212+ get_env ('ACCESS_KEY' , '' )
213+ self ._secret_key = secret_key if secret_key is not None else \
214+ get_env ('SECRET_KEY' , '' )
195215 else :
196216 self ._access_key = 'dummy'
197217 self ._secret_key = 'dummy'
198218 self ._hash_type = hash_type .lower () if hash_type is not None else \
199219 cast (str , self .DEFAULTS ['hash_type' ])
200220 arg_vfolders = set (vfolder_mounts ) if vfolder_mounts else set ()
201- env_vfolders = set (get_env ('VFOLDER_MOUNTS' , [] , clean = _clean_tokens ))
221+ env_vfolders = set (get_env ('VFOLDER_MOUNTS' , '' , clean = _clean_tokens ))
202222 self ._vfolder_mounts = [* (arg_vfolders | env_vfolders )]
203223 # prefer the argument flag and fallback to env if the flag is not set.
204224 self ._skip_sslcert_validation = (skip_sslcert_validation
205225 if skip_sslcert_validation else
206226 get_env ('SKIP_SSLCERT_VALIDATION' , 'no' , clean = bool_env ))
207227 self ._connection_timeout = connection_timeout if connection_timeout else \
208- get_env ('CONNECTION_TIMEOUT' , self .DEFAULTS ['connection_timeout' ])
228+ get_env ('CONNECTION_TIMEOUT' , self .DEFAULTS ['connection_timeout' ], clean = float )
209229 self ._read_timeout = read_timeout if read_timeout else \
210- get_env ('READ_TIMEOUT' , self .DEFAULTS ['read_timeout' ])
230+ get_env ('READ_TIMEOUT' , self .DEFAULTS ['read_timeout' ], clean = float )
211231 self ._announcement_handler = announcement_handler
212232
213233 @property
@@ -233,6 +253,9 @@ def rotate_endpoints(self):
233253 item = self ._endpoints .pop (0 )
234254 self ._endpoints .append (item )
235255
256+ def load_balance_endpoints (self ):
257+ pass
258+
236259 @property
237260 def endpoint_type (self ) -> str :
238261 """
0 commit comments