Skip to content

Commit d42480b

Browse files
authored
Merge pull request #682 from FlorentinD/fix-gds-version-init
Improve Aura-API error message / Refactor AuraAPI
2 parents bae3e7f + 0bc09f3 commit d42480b

File tree

4 files changed

+128
-109
lines changed

4 files changed

+128
-109
lines changed

graphdatascience/session/aura_api.py

Lines changed: 93 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import logging
44
import os
55
import time
6-
from typing import Any, Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Tuple
77
from urllib.parse import urlparse
88

9-
import requests as req
10-
from requests import HTTPError
9+
import requests
10+
import requests.auth
1111

1212
from graphdatascience.session.algorithm_category import AlgorithmCategory
1313
from graphdatascience.session.aura_api_responses import (
@@ -23,21 +23,13 @@
2323
from graphdatascience.version import __version__
2424

2525

26-
class AuraApi:
27-
class AuraAuthToken:
28-
access_token: str
29-
expires_in: int
30-
token_type: str
31-
32-
def __init__(self, json: Dict[str, Any]) -> None:
33-
self.access_token = json["access_token"]
34-
expires_in: int = json["expires_in"]
35-
self.expires_at = int(time.time()) + expires_in
36-
self.token_type = json["token_type"]
26+
class AuraApiError(Exception):
27+
def __init__(self, message: str, status_code: int):
28+
super().__init__(self, message)
29+
self.status_code = status_code
3730

38-
def is_expired(self) -> bool:
39-
return self.expires_at >= int(time.time())
4031

32+
class AuraApi:
4133
def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
4234
self._dev_env = os.environ.get("AURA_ENV")
4335

@@ -48,9 +40,13 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
4840
else:
4941
self._base_uri = f"https://api-{self._dev_env}.neo4j-dev.io"
5042

51-
self._credentials = (client_id, client_secret)
52-
self._token: Optional[AuraApi.AuraAuthToken] = None
43+
self._auth = AuraApi.Auth(oauth_url=f"{self._base_uri}/oauth/token", credentials=(client_id, client_secret))
5344
self._logger = logging.getLogger()
45+
46+
self._request_session = requests.Session()
47+
self._request_session.headers = {"User-agent": f"neo4j-graphdatascience-v{__version__}"}
48+
self._request_session.auth = self._auth
49+
5450
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
5551
self._tenant_details: Optional[TenantDetails] = None
5652

@@ -64,36 +60,33 @@ def extract_id(uri: str) -> str:
6460
return host.split(".")[0].split("-")[0]
6561

6662
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
67-
response = req.post(
63+
response = self._request_session.post(
6864
f"{self._base_uri}/v1beta5/data-science/sessions",
69-
headers=self._build_header(),
7065
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value},
7166
)
7267

73-
response.raise_for_status()
68+
self._check_code(response)
7469

7570
return SessionDetails.fromJson(response.json())
7671

7772
def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
78-
response = req.get(
73+
response = self._request_session.get(
7974
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}?instanceId={dbid}",
80-
headers=self._build_header(),
8175
)
8276

8377
if response.status_code == 404:
8478
return None
8579

86-
response.raise_for_status()
80+
self._check_code(response)
8781

8882
return SessionDetails.fromJson(response.json())
8983

9084
def list_sessions(self, dbid: str) -> List[SessionDetails]:
91-
response = req.get(
85+
response = self._request_session.get(
9286
f"{self._base_uri}/v1beta5/data-science/sessions?instanceId={dbid}",
93-
headers=self._build_header(),
9487
)
9588

96-
response.raise_for_status()
89+
self._check_code(response)
9790

9891
return [SessionDetails.fromJson(s) for s in response.json()]
9992

@@ -127,9 +120,8 @@ def wait_for_session_running(
127120
)
128121

129122
def delete_session(self, session_id: str, dbid: str) -> bool:
130-
response = req.delete(
123+
response = self._request_session.delete(
131124
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}",
132-
headers=self._build_header(),
133125
json={"instance_id": dbid},
134126
)
135127

@@ -138,7 +130,7 @@ def delete_session(self, session_id: str, dbid: str) -> bool:
138130
elif response.status_code == 202:
139131
return True
140132

141-
response.raise_for_status()
133+
self._check_code(response)
142134

143135
return False
144136

@@ -157,56 +149,38 @@ def create_instance(
157149
"cloud_provider": cloud_provider,
158150
}
159151

160-
response = req.post(
161-
f"{self._base_uri}/v1/instances",
162-
json=data,
163-
headers=self._build_header(),
164-
)
152+
response = self._request_session.post(f"{self._base_uri}/v1/instances", json=data)
165153

166-
try:
167-
response.raise_for_status()
168-
except HTTPError as e:
169-
print(response.json())
170-
raise e
154+
self._check_code(response)
171155

172156
return InstanceCreateDetails.from_json(response.json()["data"])
173157

174158
def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
175-
response = req.delete(
176-
f"{self._base_uri}/v1/instances/{instance_id}",
177-
headers=self._build_header(),
178-
)
159+
response = self._request_session.delete(f"{self._base_uri}/v1/instances/{instance_id}")
179160

180161
if response.status_code == 404:
181162
return None
182163

183-
response.raise_for_status()
164+
self._check_code(response)
184165

185166
return InstanceSpecificDetails.fromJson(response.json()["data"])
186167

187168
def list_instances(self) -> List[InstanceDetails]:
188-
response = req.get(
189-
f"{self._base_uri}/v1/instances",
190-
headers=self._build_header(),
191-
params={"tenantId": self._tenant_id},
192-
)
169+
response = self._request_session.get(f"{self._base_uri}/v1/instances", params={"tenantId": self._tenant_id})
193170

194-
response.raise_for_status()
171+
self._check_code(response)
195172

196173
raw_data = response.json()["data"]
197174

198175
return [InstanceDetails.fromJson(i) for i in raw_data]
199176

200177
def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
201-
response = req.get(
202-
f"{self._base_uri}/v1/instances/{instance_id}",
203-
headers=self._build_header(),
204-
)
178+
response = self._request_session.get(f"{self._base_uri}/v1/instances/{instance_id}")
205179

206180
if response.status_code == 404:
207181
return None
208182

209-
response.raise_for_status()
183+
self._check_code(response)
210184

211185
raw_data = response.json()["data"]
212186

@@ -246,17 +220,14 @@ def estimate_size(
246220
"instance_type": "dsenterprise",
247221
}
248222

249-
response = req.post(f"{self._base_uri}/v1/instances/sizing", headers=self._build_header(), json=data)
250-
response.raise_for_status()
223+
response = self._request_session.post(f"{self._base_uri}/v1/instances/sizing", json=data)
224+
self._check_code(response)
251225

252226
return EstimationDetails.from_json(response.json()["data"])
253227

254228
def _get_tenant_id(self) -> str:
255-
response = req.get(
256-
f"{self._base_uri}/v1/tenants",
257-
headers=self._build_header(),
258-
)
259-
response.raise_for_status()
229+
response = self._request_session.get(f"{self._base_uri}/v1/tenants")
230+
self._check_code(response)
260231

261232
raw_data = response.json()["data"]
262233

@@ -270,36 +241,68 @@ def _get_tenant_id(self) -> str:
270241

271242
def tenant_details(self) -> TenantDetails:
272243
if not self._tenant_details:
273-
response = req.get(
274-
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
275-
headers=self._build_header(),
276-
)
277-
response.raise_for_status()
244+
response = self._request_session.get(f"{self._base_uri}/v1/tenants/{self._tenant_id}")
245+
self._check_code(response)
278246
self._tenant_details = TenantDetails.from_json(response.json()["data"])
279247
return self._tenant_details
280248

281-
def _build_header(self) -> Dict[str, str]:
282-
return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
283-
284-
def _auth_token(self) -> str:
285-
if self._token is None or self._token.is_expired():
286-
self._token = self._update_token()
287-
return self._token.access_token
288-
289-
def _update_token(self) -> AuraAuthToken:
290-
data = {
291-
"grant_type": "client_credentials",
292-
}
293-
294-
self._logger.debug("Updating oauth token")
295-
296-
response = req.post(
297-
f"{self._base_uri}/oauth/token", data=data, auth=(self._credentials[0], self._credentials[1])
298-
)
299-
300-
response.raise_for_status()
301-
302-
return AuraApi.AuraAuthToken(response.json())
249+
def _check_code(self, resp: requests.Response) -> None:
250+
if resp.status_code >= 400:
251+
raise AuraApiError(
252+
f"Request for {resp.url} failed with status code {resp.status_code} - {resp.reason}: {resp.text}",
253+
status_code=resp.status_code,
254+
)
303255

304256
def _instance_type(self) -> str:
305257
return "enterprise-ds" if not self._dev_env else "professional-ds"
258+
259+
class Auth(requests.auth.AuthBase):
260+
class Token:
261+
access_token: str
262+
expires_in: int
263+
token_type: str
264+
265+
def __init__(self, json: Dict[str, Any]) -> None:
266+
self.access_token = json["access_token"]
267+
self.token_type = json["token_type"]
268+
269+
expires_in: int = json["expires_in"]
270+
refresh_in: int = expires_in if expires_in <= 10 else expires_in - 10
271+
# avoid token expiry during request send by refreshing 10 seconds earlier
272+
self.refresh_at = int(time.time()) + refresh_in
273+
274+
def should_refresh(self) -> bool:
275+
return self.refresh_at >= int(time.time())
276+
277+
def __init__(self, oauth_url: str, credentials: Tuple[str, str]) -> None:
278+
self._token: Optional[AuraApi.Auth.Token] = None
279+
self._logger = logging.getLogger()
280+
self._oauth_url = oauth_url
281+
self._credentials = credentials
282+
283+
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
284+
r.headers["Authorization"] = f"Bearer {self._auth_token()}"
285+
return r
286+
287+
def _auth_token(self) -> str:
288+
if self._token is None or self._token.should_refresh():
289+
self._token = self._update_token()
290+
return self._token.access_token
291+
292+
def _update_token(self) -> AuraApi.Auth.Token:
293+
data = {
294+
"grant_type": "client_credentials",
295+
}
296+
297+
self._logger.debug("Updating oauth token")
298+
299+
resp = requests.post(self._oauth_url, data=data, auth=(self._credentials[0], self._credentials[1]))
300+
301+
if resp.status_code >= 400:
302+
raise AuraApiError(
303+
"Failed to authorize with provided client credentials: "
304+
+ f"{resp.status_code} - {resp.reason}, {resp.text}",
305+
status_code=resp.status_code,
306+
)
307+
308+
return AuraApi.Auth.Token(resp.json())

graphdatascience/session/dedicated_sessions.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
from datetime import datetime, timedelta, timezone
66
from typing import List, Optional
77

8-
from requests.exceptions import HTTPError
9-
108
from graphdatascience.session.algorithm_category import AlgorithmCategory
11-
from graphdatascience.session.aura_api import AuraApi
9+
from graphdatascience.session.aura_api import AuraApi, AuraApiError
1210
from graphdatascience.session.aura_api_responses import SessionDetails
1311
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
1412
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
@@ -97,9 +95,9 @@ def list(self) -> List[SessionInfo]:
9795
for db in dbs:
9896
try:
9997
sessions.extend(self._aura_api.list_sessions(db.id))
100-
except HTTPError as e:
98+
except AuraApiError as e:
10199
# ignore 404 errors when listing sessions as it could mean paused sessions or deleted sessions
102-
if e.response.status_code != 404:
100+
if e.status_code != 404:
103101
raise e
104102

105103
return [SessionInfo.from_session_details(i) for i in sessions]
@@ -108,9 +106,9 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi
108106
matched_sessions: List[SessionDetails] = []
109107
try:
110108
matched_sessions = [s for s in self._aura_api.list_sessions(dbid) if s.name == session_name]
111-
except HTTPError as e:
109+
except AuraApiError as e:
112110
# ignore 404 errors when listing sessions as it could mean paused sessions or deleted sessions
113-
if e.response.status_code != 404:
111+
if e.status_code != 404:
114112
raise e
115113

116114
if len(matched_sessions) == 0:

0 commit comments

Comments
 (0)