Skip to content

Commit 87e982d

Browse files
committed
Extract Auth logic into middleware
1 parent 18ead39 commit 87e982d

File tree

2 files changed

+63
-75
lines changed

2 files changed

+63
-75
lines changed

graphdatascience/session/aura_api.py

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import logging
44
import os
55
import time
6-
from typing import Any, Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Tuple
77
from urllib.parse import urlparse
88

99
import requests
10+
import requests.auth
1011

1112
from graphdatascience.session.algorithm_category import AlgorithmCategory
1213
from graphdatascience.session.aura_api_responses import (
@@ -29,20 +30,6 @@ def __init__(self, message: str, status_code: int):
2930

3031

3132
class AuraApi:
32-
class AuraAuthToken:
33-
access_token: str
34-
expires_in: int
35-
token_type: str
36-
37-
def __init__(self, json: Dict[str, Any]) -> None:
38-
self.access_token = json["access_token"]
39-
expires_in: int = json["expires_in"]
40-
self.expires_at = int(time.time()) + expires_in
41-
self.token_type = json["token_type"]
42-
43-
def is_expired(self) -> bool:
44-
return self.expires_at >= int(time.time())
45-
4633
def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
4734
self._dev_env = os.environ.get("AURA_ENV")
4835

@@ -53,12 +40,13 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
5340
else:
5441
self._base_uri = f"https://api-{self._dev_env}.neo4j-dev.io"
5542

56-
self._credentials = (client_id, client_secret)
57-
self._token: Optional[AuraApi.AuraAuthToken] = None
43+
self._auth = AuraApi.Auth(oauth_url=f"{self._base_uri}/oauth/token", credentials=(client_id, client_secret))
5844
self._logger = logging.getLogger()
5945
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
6046
self._tenant_details: Optional[TenantDetails] = None
6147
self._request_session = requests.Session()
48+
self._request_session.headers = {"User-agent": f"neo4j-graphdatascience-v{__version__}"}
49+
self._request_session.auth = self._auth
6250

6351
@staticmethod
6452
def extract_id(uri: str) -> str:
@@ -72,7 +60,6 @@ def extract_id(uri: str) -> str:
7260
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
7361
response = self._request_session.post(
7462
f"{self._base_uri}/v1beta5/data-science/sessions",
75-
headers=self._build_header(),
7663
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value},
7764
)
7865

@@ -83,7 +70,6 @@ def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryVa
8370
def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
8471
response = self._request_session.get(
8572
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}?instanceId={dbid}",
86-
headers=self._build_header(),
8773
)
8874

8975
if response.status_code == 404:
@@ -96,7 +82,6 @@ def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
9682
def list_sessions(self, dbid: str) -> List[SessionDetails]:
9783
response = self._request_session.get(
9884
f"{self._base_uri}/v1beta5/data-science/sessions?instanceId={dbid}",
99-
headers=self._build_header(),
10085
)
10186

10287
self._check_code(response)
@@ -135,7 +120,6 @@ def wait_for_session_running(
135120
def delete_session(self, session_id: str, dbid: str) -> bool:
136121
response = self._request_session.delete(
137122
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}",
138-
headers=self._build_header(),
139123
json={"instance_id": dbid},
140124
)
141125

@@ -163,21 +147,14 @@ def create_instance(
163147
"cloud_provider": cloud_provider,
164148
}
165149

166-
response = self._request_session.post(
167-
f"{self._base_uri}/v1/instances",
168-
json=data,
169-
headers=self._build_header(),
170-
)
150+
response = self._request_session.post(f"{self._base_uri}/v1/instances", json=data)
171151

172152
self._check_code(response)
173153

174154
return InstanceCreateDetails.from_json(response.json()["data"])
175155

176156
def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
177-
response = self._request_session.delete(
178-
f"{self._base_uri}/v1/instances/{instance_id}",
179-
headers=self._build_header(),
180-
)
157+
response = self._request_session.delete(f"{self._base_uri}/v1/instances/{instance_id}")
181158

182159
if response.status_code == 404:
183160
return None
@@ -187,11 +164,7 @@ def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]
187164
return InstanceSpecificDetails.fromJson(response.json()["data"])
188165

189166
def list_instances(self) -> List[InstanceDetails]:
190-
response = self._request_session.get(
191-
f"{self._base_uri}/v1/instances",
192-
headers=self._build_header(),
193-
params={"tenantId": self._tenant_id},
194-
)
167+
response = self._request_session.get(f"{self._base_uri}/v1/instances", params={"tenantId": self._tenant_id})
195168

196169
self._check_code(response)
197170

@@ -200,10 +173,7 @@ def list_instances(self) -> List[InstanceDetails]:
200173
return [InstanceDetails.fromJson(i) for i in raw_data]
201174

202175
def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
203-
response = self._request_session.get(
204-
f"{self._base_uri}/v1/instances/{instance_id}",
205-
headers=self._build_header(),
206-
)
176+
response = self._request_session.get(f"{self._base_uri}/v1/instances/{instance_id}")
207177

208178
if response.status_code == 404:
209179
return None
@@ -248,18 +218,13 @@ def estimate_size(
248218
"instance_type": "dsenterprise",
249219
}
250220

251-
response = self._request_session.post(
252-
f"{self._base_uri}/v1/instances/sizing", headers=self._build_header(), json=data
253-
)
221+
response = self._request_session.post(f"{self._base_uri}/v1/instances/sizing", json=data)
254222
self._check_code(response)
255223

256224
return EstimationDetails.from_json(response.json()["data"])
257225

258226
def _get_tenant_id(self) -> str:
259-
response = self._request_session.get(
260-
f"{self._base_uri}/v1/tenants",
261-
headers=self._build_header(),
262-
)
227+
response = self._request_session.get(f"{self._base_uri}/v1/tenants")
263228
self._check_code(response)
264229

265230
raw_data = response.json()["data"]
@@ -274,37 +239,11 @@ def _get_tenant_id(self) -> str:
274239

275240
def tenant_details(self) -> TenantDetails:
276241
if not self._tenant_details:
277-
response = self._request_session.get(
278-
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
279-
headers=self._build_header(),
280-
)
242+
response = self._request_session.get(f"{self._base_uri}/v1/tenants/{self._tenant_id}")
281243
self._check_code(response)
282244
self._tenant_details = TenantDetails.from_json(response.json()["data"])
283245
return self._tenant_details
284246

285-
def _build_header(self) -> Dict[str, str]:
286-
return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
287-
288-
def _auth_token(self) -> str:
289-
if self._token is None or self._token.is_expired():
290-
self._token = self._update_token()
291-
return self._token.access_token
292-
293-
def _update_token(self) -> AuraAuthToken:
294-
data = {
295-
"grant_type": "client_credentials",
296-
}
297-
298-
self._logger.debug("Updating oauth token")
299-
300-
response = self._request_session.post(
301-
f"{self._base_uri}/oauth/token", data=data, auth=(self._credentials[0], self._credentials[1])
302-
)
303-
304-
self._check_code(response)
305-
306-
return AuraApi.AuraAuthToken(response.json())
307-
308247
def _check_code(self, resp: requests.Response) -> None:
309248
if resp.status_code >= 400:
310249
raise AuraApiError(
@@ -314,3 +253,52 @@ def _check_code(self, resp: requests.Response) -> None:
314253

315254
def _instance_type(self) -> str:
316255
return "enterprise-ds" if not self._dev_env else "professional-ds"
256+
257+
class Auth(requests.auth.AuthBase):
258+
class Token:
259+
access_token: str
260+
expires_in: int
261+
token_type: str
262+
263+
def __init__(self, json: Dict[str, Any]) -> None:
264+
self.access_token = json["access_token"]
265+
expires_in: int = json["expires_in"]
266+
self.expires_at = int(time.time()) + expires_in
267+
self.token_type = json["token_type"]
268+
269+
# TODO add a buffer of 10s to avoid nearly expiring tokens
270+
def is_expired(self) -> bool:
271+
return self.expires_at >= int(time.time())
272+
273+
def __init__(self, oauth_url: str, credentials: Tuple[str, str]) -> None:
274+
self._token: Optional[AuraApi.Auth.Token] = None
275+
self._logger = logging.getLogger()
276+
self._oauth_url = oauth_url
277+
self._credentials = credentials
278+
279+
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
280+
r.headers["Authorization"] = f"Bearer {self._auth_token()}"
281+
return r
282+
283+
def _auth_token(self) -> str:
284+
if self._token is None or self._token.is_expired():
285+
self._token = self._update_token()
286+
return self._token.access_token
287+
288+
def _update_token(self) -> AuraApi.Auth.Token:
289+
data = {
290+
"grant_type": "client_credentials",
291+
}
292+
293+
self._logger.debug("Updating oauth token")
294+
295+
resp = requests.post(self._oauth_url, data=data, auth=(self._credentials[0], self._credentials[1]))
296+
297+
if resp.status_code >= 400:
298+
raise AuraApiError(
299+
"Failed to authorize with provided client credentials: "
300+
+ f"{resp.status_code} - {resp.reason}, {resp.text}",
301+
status_code=resp.status_code,
302+
)
303+
304+
return AuraApi.Auth.Token(resp.json())

graphdatascience/tests/unit/test_aura_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,14 +355,14 @@ def test_auth_token(requests_mock: Mocker) -> None:
355355
json={"access_token": "very_short_token", "expires_in": 1, "token_type": "Bearer"},
356356
)
357357

358-
assert api._auth_token() == "very_short_token"
358+
assert api._auth._auth_token() == "very_short_token"
359359

360360
requests_mock.post(
361361
"https://api.neo4j.io/oauth/token",
362362
json={"access_token": "longer_token", "expires_in": 3600, "token_type": "Bearer"},
363363
)
364364

365-
assert api._auth_token() == "longer_token"
365+
assert api._auth._auth_token() == "longer_token"
366366

367367

368368
def test_derive_tenant(requests_mock: Mocker) -> None:

0 commit comments

Comments
 (0)