33import logging
44import os
55import time
6- from typing import Any , Dict , List , Optional
6+ from typing import Any , Dict , List , Optional , Tuple
77from urllib .parse import urlparse
88
99import requests
10+ import requests .auth
1011
1112from graphdatascience .session .algorithm_category import AlgorithmCategory
1213from graphdatascience .session .aura_api_responses import (
@@ -29,20 +30,6 @@ def __init__(self, message: str, status_code: int):
2930
3031
3132class AuraApi :
32- class AuraAuthToken :
33- access_token : str
34- expires_in : int
35- token_type : str
36-
37- def __init__ (self , json : Dict [str , Any ]) -> None :
38- self .access_token = json ["access_token" ]
39- expires_in : int = json ["expires_in" ]
40- self .expires_at = int (time .time ()) + expires_in
41- self .token_type = json ["token_type" ]
42-
43- def is_expired (self ) -> bool :
44- return self .expires_at >= int (time .time ())
45-
4633 def __init__ (self , client_id : str , client_secret : str , tenant_id : Optional [str ] = None ) -> None :
4734 self ._dev_env = os .environ .get ("AURA_ENV" )
4835
@@ -53,12 +40,13 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
5340 else :
5441 self ._base_uri = f"https://api-{ self ._dev_env } .neo4j-dev.io"
5542
56- self ._credentials = (client_id , client_secret )
57- self ._token : Optional [AuraApi .AuraAuthToken ] = None
43+ self ._auth = AuraApi .Auth (oauth_url = f"{ self ._base_uri } /oauth/token" , credentials = (client_id , client_secret ))
5844 self ._logger = logging .getLogger ()
5945 self ._tenant_id = tenant_id if tenant_id else self ._get_tenant_id ()
6046 self ._tenant_details : Optional [TenantDetails ] = None
6147 self ._request_session = requests .Session ()
48+ self ._request_session .headers = {"User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
49+ self ._request_session .auth = self ._auth
6250
6351 @staticmethod
6452 def extract_id (uri : str ) -> str :
@@ -72,7 +60,6 @@ def extract_id(uri: str) -> str:
7260 def create_session (self , name : str , dbid : str , pwd : str , memory : SessionMemoryValue ) -> SessionDetails :
7361 response = self ._request_session .post (
7462 f"{ self ._base_uri } /v1beta5/data-science/sessions" ,
75- headers = self ._build_header (),
7663 json = {"name" : name , "instance_id" : dbid , "password" : pwd , "memory" : memory .value },
7764 )
7865
@@ -83,7 +70,6 @@ def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryVa
8370 def list_session (self , session_id : str , dbid : str ) -> Optional [SessionDetails ]:
8471 response = self ._request_session .get (
8572 f"{ self ._base_uri } /v1beta5/data-science/sessions/{ session_id } ?instanceId={ dbid } " ,
86- headers = self ._build_header (),
8773 )
8874
8975 if response .status_code == 404 :
@@ -96,7 +82,6 @@ def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
9682 def list_sessions (self , dbid : str ) -> List [SessionDetails ]:
9783 response = self ._request_session .get (
9884 f"{ self ._base_uri } /v1beta5/data-science/sessions?instanceId={ dbid } " ,
99- headers = self ._build_header (),
10085 )
10186
10287 self ._check_code (response )
@@ -135,7 +120,6 @@ def wait_for_session_running(
135120 def delete_session (self , session_id : str , dbid : str ) -> bool :
136121 response = self ._request_session .delete (
137122 f"{ self ._base_uri } /v1beta5/data-science/sessions/{ session_id } " ,
138- headers = self ._build_header (),
139123 json = {"instance_id" : dbid },
140124 )
141125
@@ -163,21 +147,14 @@ def create_instance(
163147 "cloud_provider" : cloud_provider ,
164148 }
165149
166- response = self ._request_session .post (
167- f"{ self ._base_uri } /v1/instances" ,
168- json = data ,
169- headers = self ._build_header (),
170- )
150+ response = self ._request_session .post (f"{ self ._base_uri } /v1/instances" , json = data )
171151
172152 self ._check_code (response )
173153
174154 return InstanceCreateDetails .from_json (response .json ()["data" ])
175155
176156 def delete_instance (self , instance_id : str ) -> Optional [InstanceSpecificDetails ]:
177- response = self ._request_session .delete (
178- f"{ self ._base_uri } /v1/instances/{ instance_id } " ,
179- headers = self ._build_header (),
180- )
157+ response = self ._request_session .delete (f"{ self ._base_uri } /v1/instances/{ instance_id } " )
181158
182159 if response .status_code == 404 :
183160 return None
@@ -187,11 +164,7 @@ def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]
187164 return InstanceSpecificDetails .fromJson (response .json ()["data" ])
188165
189166 def list_instances (self ) -> List [InstanceDetails ]:
190- response = self ._request_session .get (
191- f"{ self ._base_uri } /v1/instances" ,
192- headers = self ._build_header (),
193- params = {"tenantId" : self ._tenant_id },
194- )
167+ response = self ._request_session .get (f"{ self ._base_uri } /v1/instances" , params = {"tenantId" : self ._tenant_id })
195168
196169 self ._check_code (response )
197170
@@ -200,10 +173,7 @@ def list_instances(self) -> List[InstanceDetails]:
200173 return [InstanceDetails .fromJson (i ) for i in raw_data ]
201174
202175 def list_instance (self , instance_id : str ) -> Optional [InstanceSpecificDetails ]:
203- response = self ._request_session .get (
204- f"{ self ._base_uri } /v1/instances/{ instance_id } " ,
205- headers = self ._build_header (),
206- )
176+ response = self ._request_session .get (f"{ self ._base_uri } /v1/instances/{ instance_id } " )
207177
208178 if response .status_code == 404 :
209179 return None
@@ -248,18 +218,13 @@ def estimate_size(
248218 "instance_type" : "dsenterprise" ,
249219 }
250220
251- response = self ._request_session .post (
252- f"{ self ._base_uri } /v1/instances/sizing" , headers = self ._build_header (), json = data
253- )
221+ response = self ._request_session .post (f"{ self ._base_uri } /v1/instances/sizing" , json = data )
254222 self ._check_code (response )
255223
256224 return EstimationDetails .from_json (response .json ()["data" ])
257225
258226 def _get_tenant_id (self ) -> str :
259- response = self ._request_session .get (
260- f"{ self ._base_uri } /v1/tenants" ,
261- headers = self ._build_header (),
262- )
227+ response = self ._request_session .get (f"{ self ._base_uri } /v1/tenants" )
263228 self ._check_code (response )
264229
265230 raw_data = response .json ()["data" ]
@@ -274,37 +239,11 @@ def _get_tenant_id(self) -> str:
274239
275240 def tenant_details (self ) -> TenantDetails :
276241 if not self ._tenant_details :
277- response = self ._request_session .get (
278- f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " ,
279- headers = self ._build_header (),
280- )
242+ response = self ._request_session .get (f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " )
281243 self ._check_code (response )
282244 self ._tenant_details = TenantDetails .from_json (response .json ()["data" ])
283245 return self ._tenant_details
284246
285- def _build_header (self ) -> Dict [str , str ]:
286- return {"Authorization" : f"Bearer { self ._auth_token ()} " , "User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
287-
288- def _auth_token (self ) -> str :
289- if self ._token is None or self ._token .is_expired ():
290- self ._token = self ._update_token ()
291- return self ._token .access_token
292-
293- def _update_token (self ) -> AuraAuthToken :
294- data = {
295- "grant_type" : "client_credentials" ,
296- }
297-
298- self ._logger .debug ("Updating oauth token" )
299-
300- response = self ._request_session .post (
301- f"{ self ._base_uri } /oauth/token" , data = data , auth = (self ._credentials [0 ], self ._credentials [1 ])
302- )
303-
304- self ._check_code (response )
305-
306- return AuraApi .AuraAuthToken (response .json ())
307-
308247 def _check_code (self , resp : requests .Response ) -> None :
309248 if resp .status_code >= 400 :
310249 raise AuraApiError (
@@ -314,3 +253,52 @@ def _check_code(self, resp: requests.Response) -> None:
314253
315254 def _instance_type (self ) -> str :
316255 return "enterprise-ds" if not self ._dev_env else "professional-ds"
256+
257+ class Auth (requests .auth .AuthBase ):
258+ class Token :
259+ access_token : str
260+ expires_in : int
261+ token_type : str
262+
263+ def __init__ (self , json : Dict [str , Any ]) -> None :
264+ self .access_token = json ["access_token" ]
265+ expires_in : int = json ["expires_in" ]
266+ self .expires_at = int (time .time ()) + expires_in
267+ self .token_type = json ["token_type" ]
268+
269+ # TODO add a buffer of 10s to avoid nearly expiring tokens
270+ def is_expired (self ) -> bool :
271+ return self .expires_at >= int (time .time ())
272+
273+ def __init__ (self , oauth_url : str , credentials : Tuple [str , str ]) -> None :
274+ self ._token : Optional [AuraApi .Auth .Token ] = None
275+ self ._logger = logging .getLogger ()
276+ self ._oauth_url = oauth_url
277+ self ._credentials = credentials
278+
279+ def __call__ (self , r : requests .PreparedRequest ) -> requests .PreparedRequest :
280+ r .headers ["Authorization" ] = f"Bearer { self ._auth_token ()} "
281+ return r
282+
283+ def _auth_token (self ) -> str :
284+ if self ._token is None or self ._token .is_expired ():
285+ self ._token = self ._update_token ()
286+ return self ._token .access_token
287+
288+ def _update_token (self ) -> AuraApi .Auth .Token :
289+ data = {
290+ "grant_type" : "client_credentials" ,
291+ }
292+
293+ self ._logger .debug ("Updating oauth token" )
294+
295+ resp = requests .post (self ._oauth_url , data = data , auth = (self ._credentials [0 ], self ._credentials [1 ]))
296+
297+ if resp .status_code >= 400 :
298+ raise AuraApiError (
299+ "Failed to authorize with provided client credentials: "
300+ + f"{ resp .status_code } - { resp .reason } , { resp .text } " ,
301+ status_code = resp .status_code ,
302+ )
303+
304+ return AuraApi .Auth .Token (resp .json ())
0 commit comments