Skip to content

Commit 18ead39

Browse files
committed
1 parent d647bd3 commit 18ead39

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

graphdatascience/session/aura_api.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from graphdatascience.session.session_sizes import SessionMemoryValue
2222
from graphdatascience.version import __version__
2323

24+
2425
class AuraApiError(Exception):
2526
def __init__(self, message: str, status_code: int):
2627
super().__init__(self, message)
2728
self.status_code = status_code
2829

30+
2931
class AuraApi:
3032
class AuraAuthToken:
3133
access_token: str
@@ -56,6 +58,7 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
5658
self._logger = logging.getLogger()
5759
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
5860
self._tenant_details: Optional[TenantDetails] = None
61+
self._request_session = requests.Session()
5962

6063
@staticmethod
6164
def extract_id(uri: str) -> str:
@@ -67,7 +70,7 @@ def extract_id(uri: str) -> str:
6770
return host.split(".")[0].split("-")[0]
6871

6972
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
70-
response = requests.post(
73+
response = self._request_session.post(
7174
f"{self._base_uri}/v1beta5/data-science/sessions",
7275
headers=self._build_header(),
7376
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value},
@@ -78,7 +81,7 @@ def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryVa
7881
return SessionDetails.fromJson(response.json())
7982

8083
def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
81-
response = requests.get(
84+
response = self._request_session.get(
8285
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}?instanceId={dbid}",
8386
headers=self._build_header(),
8487
)
@@ -91,7 +94,7 @@ def list_session(self, session_id: str, dbid: str) -> Optional[SessionDetails]:
9194
return SessionDetails.fromJson(response.json())
9295

9396
def list_sessions(self, dbid: str) -> List[SessionDetails]:
94-
response = requests.get(
97+
response = self._request_session.get(
9598
f"{self._base_uri}/v1beta5/data-science/sessions?instanceId={dbid}",
9699
headers=self._build_header(),
97100
)
@@ -130,7 +133,7 @@ def wait_for_session_running(
130133
)
131134

132135
def delete_session(self, session_id: str, dbid: str) -> bool:
133-
response = requests.delete(
136+
response = self._request_session.delete(
134137
f"{self._base_uri}/v1beta5/data-science/sessions/{session_id}",
135138
headers=self._build_header(),
136139
json={"instance_id": dbid},
@@ -160,7 +163,7 @@ def create_instance(
160163
"cloud_provider": cloud_provider,
161164
}
162165

163-
response = requests.post(
166+
response = self._request_session.post(
164167
f"{self._base_uri}/v1/instances",
165168
json=data,
166169
headers=self._build_header(),
@@ -171,7 +174,7 @@ def create_instance(
171174
return InstanceCreateDetails.from_json(response.json()["data"])
172175

173176
def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
174-
response = requests.delete(
177+
response = self._request_session.delete(
175178
f"{self._base_uri}/v1/instances/{instance_id}",
176179
headers=self._build_header(),
177180
)
@@ -184,7 +187,7 @@ def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]
184187
return InstanceSpecificDetails.fromJson(response.json()["data"])
185188

186189
def list_instances(self) -> List[InstanceDetails]:
187-
response = requests.get(
190+
response = self._request_session.get(
188191
f"{self._base_uri}/v1/instances",
189192
headers=self._build_header(),
190193
params={"tenantId": self._tenant_id},
@@ -197,7 +200,7 @@ def list_instances(self) -> List[InstanceDetails]:
197200
return [InstanceDetails.fromJson(i) for i in raw_data]
198201

199202
def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
200-
response = requests.get(
203+
response = self._request_session.get(
201204
f"{self._base_uri}/v1/instances/{instance_id}",
202205
headers=self._build_header(),
203206
)
@@ -245,13 +248,15 @@ def estimate_size(
245248
"instance_type": "dsenterprise",
246249
}
247250

248-
response = requests.post(f"{self._base_uri}/v1/instances/sizing", headers=self._build_header(), json=data)
251+
response = self._request_session.post(
252+
f"{self._base_uri}/v1/instances/sizing", headers=self._build_header(), json=data
253+
)
249254
self._check_code(response)
250255

251256
return EstimationDetails.from_json(response.json()["data"])
252257

253258
def _get_tenant_id(self) -> str:
254-
response = requests.get(
259+
response = self._request_session.get(
255260
f"{self._base_uri}/v1/tenants",
256261
headers=self._build_header(),
257262
)
@@ -269,7 +274,7 @@ def _get_tenant_id(self) -> str:
269274

270275
def tenant_details(self) -> TenantDetails:
271276
if not self._tenant_details:
272-
response = requests.get(
277+
response = self._request_session.get(
273278
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
274279
headers=self._build_header(),
275280
)
@@ -292,7 +297,7 @@ def _update_token(self) -> AuraAuthToken:
292297

293298
self._logger.debug("Updating oauth token")
294299

295-
response = requests.post(
300+
response = self._request_session.post(
296301
f"{self._base_uri}/oauth/token", data=data, auth=(self._credentials[0], self._credentials[1])
297302
)
298303

0 commit comments

Comments
 (0)