1414from conductor .asyncio_client .http .api_response import ApiResponse
1515from conductor .asyncio_client .http .api_response import T as ApiResponseT
1616from conductor .asyncio_client .http .exceptions import ApiException
17+ from conductor .client .exceptions .auth_401_policy import Auth401Policy , Auth401Handler
1718
1819logger = logging .getLogger (Configuration .get_logging_formatted_name (__name__ ))
1920
2021
2122class ApiClientAdapter (ApiClient ):
22- def __init__ (self , * args , ** kwargs ):
23+ def __init__ (
24+ self , configuration = None , header_name = None , header_value = None , cookie = None
25+ ):
2326 self ._token_lock = asyncio .Lock ()
24- super ().__init__ (* args , ** kwargs )
27+ self .configuration = configuration or Configuration ()
28+
29+ self .rest_client = rest .RESTClientObject (self .configuration )
30+ self .default_headers = {}
31+ if header_name is not None :
32+ self .default_headers [header_name ] = header_value
33+ self .cookie = cookie
34+
35+ # Initialize 401 policy handler
36+ auth_401_policy = Auth401Policy (
37+ max_attempts = self .configuration .auth_401_max_attempts ,
38+ base_delay_ms = self .configuration .auth_401_base_delay_ms ,
39+ max_delay_ms = self .configuration .auth_401_max_delay_ms ,
40+ jitter_percent = self .configuration .auth_401_jitter_percent ,
41+ stop_behavior = self .configuration .auth_401_stop_behavior ,
42+ )
43+ self .auth_401_handler = Auth401Handler (auth_401_policy )
2544
2645 async def call_api (
2746 self ,
@@ -46,7 +65,10 @@ async def call_api(
4665
4766 try :
4867 logger .debug (
49- "HTTP request method: %s; url: %s; header_params: %s" , method , url , header_params
68+ "HTTP request method: %s; url: %s; header_params: %s" ,
69+ method ,
70+ url ,
71+ header_params ,
5072 )
5173 response_data = await self .rest_client .request (
5274 method ,
@@ -56,41 +78,133 @@ async def call_api(
5678 post_params = post_params ,
5779 _request_timeout = _request_timeout ,
5880 )
59- if (
81+
82+ # Handle 401 retries with policy-based logic
83+ resource_path = url .replace (self .configuration .host , "" )
84+
85+ # Loop to handle multiple 401 retries
86+ while (
6087 response_data .status == 401 # noqa: PLR2004 (Unauthorized status code)
6188 and url != self .configuration .host + "/token"
6289 ):
63- logger .warning (
64- "HTTP response from: %s; status code: 401 - obtaining new token" , url
65- )
66- async with self ._token_lock :
67- # The lock is intentionally broad (covers the whole block including the token state)
68- # to avoid race conditions: without it, other coroutines could mis-evaluate
69- # token state during a context switch and trigger redundant refreshes
70- token_expired = (
71- self .configuration .token_update_time > 0
72- and time .time ()
73- >= self .configuration .token_update_time
74- + self .configuration .auth_token_ttl_sec
90+ # Check if this is an auth-dependent call that should trigger 401 policy
91+ if self .auth_401_handler .policy .is_auth_dependent_call (
92+ resource_path , method
93+ ):
94+ # Handle 401 with policy (exponential backoff, max attempts, etc.)
95+ result = self .auth_401_handler .handle_401_error (
96+ resource_path = resource_path ,
97+ method = method ,
98+ status_code = 401 ,
99+ error_code = None ,
75100 )
76- invalid_token = not self .configuration ._http_config .api_key .get ("api_key" )
77101
78- if invalid_token or token_expired :
79- token = await self .refresh_authorization_token ()
102+ if result ["should_retry" ]:
103+ # Apply exponential backoff delay
104+ if result ["delay_seconds" ] > 0 :
105+ logger .info (
106+ "401 error on %s %s - waiting %.2fs before retry (attempt %d/%d)" ,
107+ method ,
108+ url ,
109+ result ["delay_seconds" ],
110+ result ["attempt_count" ],
111+ result ["max_attempts" ],
112+ )
113+ await asyncio .sleep (result ["delay_seconds" ])
114+
115+ # Try to refresh token and retry
116+ async with self ._token_lock :
117+ # Check if token was already refreshed by another coroutine
118+ # to avoid race condition where multiple concurrent 401s
119+ # trigger redundant token refreshes
120+ token_expired = (
121+ self .configuration .token_update_time > 0
122+ and time .time ()
123+ >= self .configuration .token_update_time
124+ + self .configuration .auth_token_ttl_sec
125+ )
126+ invalid_token = (
127+ not self .configuration ._http_config .api_key .get (
128+ "api_key"
129+ )
130+ )
131+
132+ if invalid_token or token_expired :
133+ token = await self .refresh_authorization_token ()
134+ else :
135+ token = self .configuration ._http_config .api_key [
136+ "api_key"
137+ ]
138+ if header_params is None :
139+ header_params = {}
140+ header_params ["X-Authorization" ] = token
141+
142+ # Make the retry request outside the lock to avoid blocking other coroutines
143+ response_data = await self .rest_client .request (
144+ method ,
145+ url ,
146+ headers = header_params ,
147+ body = body ,
148+ post_params = post_params ,
149+ _request_timeout = _request_timeout ,
150+ )
80151 else :
81- token = self .configuration ._http_config .api_key ["api_key" ]
82- header_params ["X-Authorization" ] = token
83- response_data = await self .rest_client .request (
84- method ,
152+ # Max attempts reached - log error and break
153+ logger .error (
154+ "401 error on %s %s - max attempts (%d) reached, stopping worker" ,
155+ method ,
156+ url ,
157+ result ["max_attempts" ],
158+ )
159+ break
160+ else :
161+ # Non-auth-dependent call with 401 - use original behavior (single retry)
162+ logger .warning (
163+ "HTTP response from: %s; status code: 401 - obtaining new token" ,
85164 url ,
86- headers = header_params ,
87- body = body ,
88- post_params = post_params ,
89- _request_timeout = _request_timeout ,
90165 )
166+ async with self ._token_lock :
167+ # The lock is intentionally broad (covers the whole block including the token state)
168+ # to avoid race conditions: without it, other coroutines could mis-evaluate
169+ # token state during a context switch and trigger redundant refreshes
170+ token_expired = (
171+ self .configuration .token_update_time > 0
172+ and time .time ()
173+ >= self .configuration .token_update_time
174+ + self .configuration .auth_token_ttl_sec
175+ )
176+ invalid_token = not self .configuration ._http_config .api_key .get (
177+ "api_key"
178+ )
179+
180+ if invalid_token or token_expired :
181+ token = await self .refresh_authorization_token ()
182+ else :
183+ token = self .configuration ._http_config .api_key ["api_key" ]
184+ if header_params is None :
185+ header_params = {}
186+ header_params ["X-Authorization" ] = token
187+ response_data = await self .rest_client .request (
188+ method ,
189+ url ,
190+ headers = header_params ,
191+ body = body ,
192+ post_params = post_params ,
193+ _request_timeout = _request_timeout ,
194+ )
195+ # Break after single retry for non-auth-dependent calls
196+ break
197+
198+ # Record successful call to reset 401 attempt counters
199+ if response_data .status != 401 :
200+ self .auth_401_handler .record_successful_call (resource_path )
201+
91202 except ApiException as e :
92203 logger .error (
93- "HTTP request failed url: %s status: %s; reason: %s" , url , e .status , e .reason
204+ "HTTP request failed url: %s status: %s; reason: %s" ,
205+ url ,
206+ e .status ,
207+ e .reason ,
94208 )
95209 raise e
96210
@@ -117,7 +231,9 @@ def response_deserialize(
117231 and 100 <= response_data .status <= 599 # noqa: PLR2004
118232 ):
119233 # if not found, look for '1XX', '2XX', etc.
120- response_type = response_types_map .get (str (response_data .status )[0 ] + "XX" , None )
234+ response_type = response_types_map .get (
235+ str (response_data .status )[0 ] + "XX" , None
236+ )
121237
122238 # deserialize response data
123239 response_text = None
@@ -134,10 +250,14 @@ def response_deserialize(
134250 match = re .search (r"charset=([a-zA-Z\-\d]+)[\s;]?" , content_type )
135251 encoding = match .group (1 ) if match else "utf-8"
136252 response_text = response_data .data .decode (encoding )
137- return_data = self .deserialize (response_text , response_type , content_type )
253+ return_data = self .deserialize (
254+ response_text , response_type , content_type
255+ )
138256 finally :
139257 if not 200 <= response_data .status <= 299 : # noqa: PLR2004
140- logger .error ("Unexpected response status code: %s" , response_data .status )
258+ logger .error (
259+ "Unexpected response status code: %s" , response_data .status
260+ )
141261 raise ApiException .from_response (
142262 http_resp = response_data ,
143263 body = response_text ,
0 commit comments