33import logging
44import os
55import time
6- from typing import Any , Dict , List , Optional
6+ from typing import Any , Dict , List , Optional , Tuple
77from urllib .parse import urlparse
88
9- import requests as req
10- from requests import HTTPError
9+ import requests
10+ import requests . auth
1111
1212from graphdatascience .session .algorithm_category import AlgorithmCategory
1313from graphdatascience .session .aura_api_responses import (
2323from graphdatascience .version import __version__
2424
2525
26- class AuraApi :
27- class AuraAuthToken :
28- access_token : str
29- expires_in : int
30- token_type : str
31-
32- def __init__ (self , json : Dict [str , Any ]) -> None :
33- self .access_token = json ["access_token" ]
34- expires_in : int = json ["expires_in" ]
35- self .expires_at = int (time .time ()) + expires_in
36- self .token_type = json ["token_type" ]
26+ class AuraApiError (Exception ):
27+ def __init__ (self , message : str , status_code : int ):
28+ super ().__init__ (self , message )
29+ self .status_code = status_code
3730
38- def is_expired (self ) -> bool :
39- return self .expires_at >= int (time .time ())
4031
32+ class AuraApi :
4133 def __init__ (self , client_id : str , client_secret : str , tenant_id : Optional [str ] = None ) -> None :
4234 self ._dev_env = os .environ .get ("AURA_ENV" )
4335
@@ -48,9 +40,13 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
4840 else :
4941 self ._base_uri = f"https://api-{ self ._dev_env } .neo4j-dev.io"
5042
51- self ._credentials = (client_id , client_secret )
52- self ._token : Optional [AuraApi .AuraAuthToken ] = None
43+ self ._auth = AuraApi .Auth (oauth_url = f"{ self ._base_uri } /oauth/token" , credentials = (client_id , client_secret ))
5344 self ._logger = logging .getLogger ()
45+
46+ self ._request_session = requests .Session ()
47+ self ._request_session .headers = {"User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
48+ self ._request_session .auth = self ._auth
49+
5450 self ._tenant_id = tenant_id if tenant_id else self ._get_tenant_id ()
5551 self ._tenant_details : Optional [TenantDetails ] = None
5652
@@ -64,36 +60,33 @@ def extract_id(uri: str) -> str:
6460 return host .split ("." )[0 ].split ("-" )[0 ]
6561
6662 def create_session (self , name : str , dbid : str , pwd : str , memory : SessionMemoryValue ) -> SessionDetails :
67- response = req .post (
63+ response = self . _request_session .post (
6864 f"{ self ._base_uri } /v1beta5/data-science/sessions" ,
69- headers = self ._build_header (),
7065 json = {"name" : name , "instance_id" : dbid , "password" : pwd , "memory" : memory .value },
7166 )
7267
73- response . raise_for_status ( )
68+ self . _check_code ( response )
7469
7570 return SessionDetails .fromJson (response .json ())
7671
7772 def list_session (self , session_id : str , dbid : str ) -> Optional [SessionDetails ]:
78- response = req .get (
73+ response = self . _request_session .get (
7974 f"{ self ._base_uri } /v1beta5/data-science/sessions/{ session_id } ?instanceId={ dbid } " ,
80- headers = self ._build_header (),
8175 )
8276
8377 if response .status_code == 404 :
8478 return None
8579
86- response . raise_for_status ( )
80+ self . _check_code ( response )
8781
8882 return SessionDetails .fromJson (response .json ())
8983
9084 def list_sessions (self , dbid : str ) -> List [SessionDetails ]:
91- response = req .get (
85+ response = self . _request_session .get (
9286 f"{ self ._base_uri } /v1beta5/data-science/sessions?instanceId={ dbid } " ,
93- headers = self ._build_header (),
9487 )
9588
96- response . raise_for_status ( )
89+ self . _check_code ( response )
9790
9891 return [SessionDetails .fromJson (s ) for s in response .json ()]
9992
@@ -127,9 +120,8 @@ def wait_for_session_running(
127120 )
128121
129122 def delete_session (self , session_id : str , dbid : str ) -> bool :
130- response = req .delete (
123+ response = self . _request_session .delete (
131124 f"{ self ._base_uri } /v1beta5/data-science/sessions/{ session_id } " ,
132- headers = self ._build_header (),
133125 json = {"instance_id" : dbid },
134126 )
135127
@@ -138,7 +130,7 @@ def delete_session(self, session_id: str, dbid: str) -> bool:
138130 elif response .status_code == 202 :
139131 return True
140132
141- response . raise_for_status ( )
133+ self . _check_code ( response )
142134
143135 return False
144136
@@ -157,56 +149,38 @@ def create_instance(
157149 "cloud_provider" : cloud_provider ,
158150 }
159151
160- response = req .post (
161- f"{ self ._base_uri } /v1/instances" ,
162- json = data ,
163- headers = self ._build_header (),
164- )
152+ response = self ._request_session .post (f"{ self ._base_uri } /v1/instances" , json = data )
165153
166- try :
167- response .raise_for_status ()
168- except HTTPError as e :
169- print (response .json ())
170- raise e
154+ self ._check_code (response )
171155
172156 return InstanceCreateDetails .from_json (response .json ()["data" ])
173157
174158 def delete_instance (self , instance_id : str ) -> Optional [InstanceSpecificDetails ]:
175- response = req .delete (
176- f"{ self ._base_uri } /v1/instances/{ instance_id } " ,
177- headers = self ._build_header (),
178- )
159+ response = self ._request_session .delete (f"{ self ._base_uri } /v1/instances/{ instance_id } " )
179160
180161 if response .status_code == 404 :
181162 return None
182163
183- response . raise_for_status ( )
164+ self . _check_code ( response )
184165
185166 return InstanceSpecificDetails .fromJson (response .json ()["data" ])
186167
187168 def list_instances (self ) -> List [InstanceDetails ]:
188- response = req .get (
189- f"{ self ._base_uri } /v1/instances" ,
190- headers = self ._build_header (),
191- params = {"tenantId" : self ._tenant_id },
192- )
169+ response = self ._request_session .get (f"{ self ._base_uri } /v1/instances" , params = {"tenantId" : self ._tenant_id })
193170
194- response . raise_for_status ( )
171+ self . _check_code ( response )
195172
196173 raw_data = response .json ()["data" ]
197174
198175 return [InstanceDetails .fromJson (i ) for i in raw_data ]
199176
200177 def list_instance (self , instance_id : str ) -> Optional [InstanceSpecificDetails ]:
201- response = req .get (
202- f"{ self ._base_uri } /v1/instances/{ instance_id } " ,
203- headers = self ._build_header (),
204- )
178+ response = self ._request_session .get (f"{ self ._base_uri } /v1/instances/{ instance_id } " )
205179
206180 if response .status_code == 404 :
207181 return None
208182
209- response . raise_for_status ( )
183+ self . _check_code ( response )
210184
211185 raw_data = response .json ()["data" ]
212186
@@ -246,17 +220,14 @@ def estimate_size(
246220 "instance_type" : "dsenterprise" ,
247221 }
248222
249- response = req . post (f"{ self ._base_uri } /v1/instances/sizing" , headers = self . _build_header () , json = data )
250- response . raise_for_status ( )
223+ response = self . _request_session . post (f"{ self ._base_uri } /v1/instances/sizing" , json = data )
224+ self . _check_code ( response )
251225
252226 return EstimationDetails .from_json (response .json ()["data" ])
253227
254228 def _get_tenant_id (self ) -> str :
255- response = req .get (
256- f"{ self ._base_uri } /v1/tenants" ,
257- headers = self ._build_header (),
258- )
259- response .raise_for_status ()
229+ response = self ._request_session .get (f"{ self ._base_uri } /v1/tenants" )
230+ self ._check_code (response )
260231
261232 raw_data = response .json ()["data" ]
262233
@@ -270,36 +241,68 @@ def _get_tenant_id(self) -> str:
270241
271242 def tenant_details (self ) -> TenantDetails :
272243 if not self ._tenant_details :
273- response = req .get (
274- f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " ,
275- headers = self ._build_header (),
276- )
277- response .raise_for_status ()
244+ response = self ._request_session .get (f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " )
245+ self ._check_code (response )
278246 self ._tenant_details = TenantDetails .from_json (response .json ()["data" ])
279247 return self ._tenant_details
280248
281- def _build_header (self ) -> Dict [str , str ]:
282- return {"Authorization" : f"Bearer { self ._auth_token ()} " , "User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
283-
284- def _auth_token (self ) -> str :
285- if self ._token is None or self ._token .is_expired ():
286- self ._token = self ._update_token ()
287- return self ._token .access_token
288-
289- def _update_token (self ) -> AuraAuthToken :
290- data = {
291- "grant_type" : "client_credentials" ,
292- }
293-
294- self ._logger .debug ("Updating oauth token" )
295-
296- response = req .post (
297- f"{ self ._base_uri } /oauth/token" , data = data , auth = (self ._credentials [0 ], self ._credentials [1 ])
298- )
299-
300- response .raise_for_status ()
301-
302- return AuraApi .AuraAuthToken (response .json ())
249+ def _check_code (self , resp : requests .Response ) -> None :
250+ if resp .status_code >= 400 :
251+ raise AuraApiError (
252+ f"Request for { resp .url } failed with status code { resp .status_code } - { resp .reason } : { resp .text } " ,
253+ status_code = resp .status_code ,
254+ )
303255
304256 def _instance_type (self ) -> str :
305257 return "enterprise-ds" if not self ._dev_env else "professional-ds"
258+
259+ class Auth (requests .auth .AuthBase ):
260+ class Token :
261+ access_token : str
262+ expires_in : int
263+ token_type : str
264+
265+ def __init__ (self , json : Dict [str , Any ]) -> None :
266+ self .access_token = json ["access_token" ]
267+ self .token_type = json ["token_type" ]
268+
269+ expires_in : int = json ["expires_in" ]
270+ refresh_in : int = expires_in if expires_in <= 10 else expires_in - 10
271+ # avoid token expiry during request send by refreshing 10 seconds earlier
272+ self .refresh_at = int (time .time ()) + refresh_in
273+
274+ def should_refresh (self ) -> bool :
275+ return self .refresh_at >= int (time .time ())
276+
277+ def __init__ (self , oauth_url : str , credentials : Tuple [str , str ]) -> None :
278+ self ._token : Optional [AuraApi .Auth .Token ] = None
279+ self ._logger = logging .getLogger ()
280+ self ._oauth_url = oauth_url
281+ self ._credentials = credentials
282+
283+ def __call__ (self , r : requests .PreparedRequest ) -> requests .PreparedRequest :
284+ r .headers ["Authorization" ] = f"Bearer { self ._auth_token ()} "
285+ return r
286+
287+ def _auth_token (self ) -> str :
288+ if self ._token is None or self ._token .should_refresh ():
289+ self ._token = self ._update_token ()
290+ return self ._token .access_token
291+
292+ def _update_token (self ) -> AuraApi .Auth .Token :
293+ data = {
294+ "grant_type" : "client_credentials" ,
295+ }
296+
297+ self ._logger .debug ("Updating oauth token" )
298+
299+ resp = requests .post (self ._oauth_url , data = data , auth = (self ._credentials [0 ], self ._credentials [1 ]))
300+
301+ if resp .status_code >= 400 :
302+ raise AuraApiError (
303+ "Failed to authorize with provided client credentials: "
304+ + f"{ resp .status_code } - { resp .reason } , { resp .text } " ,
305+ status_code = resp .status_code ,
306+ )
307+
308+ return AuraApi .Auth .Token (resp .json ())
0 commit comments