Skip to content

Commit b476162

Browse files
authored
Merge pull request #662 from DarthMax/dedicated-session-check-memory
Dedicated session memory checks
2 parents 2fbe7b2 + 1e0e3c8 commit b476162

File tree

11 files changed

+180
-92
lines changed

11 files changed

+180
-92
lines changed

graphdatascience/session/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from .dbms_connection_info import DbmsConnectionInfo
33
from .gds_sessions import AuraAPICredentials, GdsSessions
44
from .session_info import SessionInfo
5-
from .session_sizes import SessionMemory
5+
from .session_sizes import SessionMemory, SessionMemoryValue
66

77
__all__ = [
88
"GdsSessions",
99
"SessionInfo",
1010
"DbmsConnectionInfo",
1111
"AuraAPICredentials",
1212
"SessionMemory",
13+
"SessionMemoryValue",
1314
"AlgorithmCategory",
1415
]

graphdatascience/session/aura_api.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
TenantDetails,
2020
WaitResult,
2121
)
22+
from graphdatascience.session.session_sizes import SessionMemoryValue
2223
from graphdatascience.version import __version__
2324

2425

@@ -62,11 +63,11 @@ def extract_id(uri: str) -> str:
6263

6364
return host.split(".")[0].split("-")[0]
6465

65-
def create_session(self, name: str, dbid: str, pwd: str, memory: str) -> SessionDetails:
66+
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
6667
response = req.post(
6768
f"{self._base_uri}/v1beta5/data-science/sessions",
6869
headers=self._build_header(),
69-
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory},
70+
json={"name": name, "instance_id": dbid, "password": pwd, "memory": memory.value},
7071
)
7172

7273
response.raise_for_status()
@@ -141,12 +142,14 @@ def delete_session(self, session_id: str, dbid: str) -> bool:
141142

142143
return False
143144

144-
def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
145+
def create_instance(
146+
self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str
147+
) -> InstanceCreateDetails:
145148
tenant_details = self.tenant_details()
146149

147150
data = {
148151
"name": name,
149-
"memory": memory,
152+
"memory": memory.value,
150153
"version": "5",
151154
"region": region,
152155
"type": tenant_details.ds_type,

graphdatascience/session/aura_api_responses.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
from pandas import Timedelta
1111

12+
from .session_sizes import SessionMemoryValue
13+
1214

1315
@dataclass(repr=True, frozen=True)
1416
class SessionDetails:
1517
id: str
1618
name: str
1719
instance_id: str
18-
memory: str
20+
memory: SessionMemoryValue
1921
status: str
2022
host: str
2123
created_at: datetime
@@ -31,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails:
3133
id=json["id"],
3234
name=json["name"],
3335
instance_id=json["instance_id"],
34-
memory=json["memory"],
36+
memory=SessionMemoryValue.fromApiResponse(json["memory"]),
3537
status=json["status"],
3638
host=json["host"],
3739
expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None,
@@ -67,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails:
6769
class InstanceSpecificDetails(InstanceDetails):
6870
status: str
6971
connection_url: str
70-
memory: str
72+
memory: SessionMemoryValue
7173
type: str
7274
region: str
7375

@@ -80,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails:
8082
cloud_provider=json["cloud_provider"],
8183
status=json["status"],
8284
connection_url=json.get("connection_url", ""),
83-
memory=json.get("memory", ""),
85+
memory=SessionMemoryValue.fromApiResponse(json.get("memory", "")),
8486
type=json["type"],
8587
region=json["region"],
8688
)

graphdatascience/session/aurads_sessions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1717
from graphdatascience.session.region_suggester import closest_match
1818
from graphdatascience.session.session_info import SessionInfo
19-
from graphdatascience.session.session_sizes import SessionMemory
19+
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue
2020

2121

2222
class AuraDsSessions:
@@ -41,7 +41,7 @@ def estimate(
4141
ResourceWarning,
4242
)
4343

44-
return SessionMemory(estimation.recommended_size)
44+
return SessionMemory(SessionMemoryValue(estimation.recommended_size))
4545

4646
def get_or_create(
4747
self,
@@ -54,13 +54,13 @@ def get_or_create(
5454
if existing_session:
5555
session_id = existing_session.id
5656
# 0MB is AuraAPI default value for memory if none can be retrieved
57-
if existing_session.memory != "0MB" and existing_session.memory != memory.value:
57+
if existing_session.memory.value != "0MB" and existing_session.memory != memory.value:
5858
raise ValueError(
59-
f"Session `{session_name}` already exists with memory `{existing_session.memory}`. "
59+
f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. "
6060
f"Requested memory `{memory.value}` does not match."
6161
)
6262
else:
63-
create_details = self._create_session(session_name, memory, db_connection)
63+
create_details = self._create_session(session_name, memory.value, db_connection)
6464
session_id = create_details.id
6565

6666
wait_result = self._aura_api.wait_for_instance_running(session_id)
@@ -118,7 +118,7 @@ def _find_existing_session(self, session_name: str) -> Optional[InstanceSpecific
118118
return self._aura_api.list_instance(matched_instances[0].id)
119119

120120
def _create_session(
121-
self, session_name: str, memory: SessionMemory, db_connection: DbmsConnectionInfo
121+
self, session_name: str, memory: SessionMemoryValue, db_connection: DbmsConnectionInfo
122122
) -> InstanceCreateDetails:
123123
db_instance_id = AuraApi.extract_id(db_connection.uri)
124124
db_instance = self._aura_api.list_instance(db_instance_id)
@@ -128,7 +128,7 @@ def _create_session(
128128
region = self._ds_region(db_instance.region, db_instance.cloud_provider)
129129

130130
create_details = self._aura_api.create_instance(
131-
SessionNameHelper.instance_name(session_name), memory.value, db_instance.cloud_provider, region
131+
SessionNameHelper.instance_name(session_name), memory, db_instance.cloud_provider, region
132132
)
133133
return create_details
134134

graphdatascience/session/dedicated_sessions.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
1212
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1313
from graphdatascience.session.session_info import SessionInfo
14-
from graphdatascience.session.session_sizes import SessionMemory
14+
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue
1515

1616

1717
class DedicatedSessions:
@@ -35,7 +35,7 @@ def estimate(
3535
ResourceWarning,
3636
)
3737

38-
return SessionMemory(estimation.recommended_size)
38+
return SessionMemory(SessionMemoryValue(estimation.recommended_size))
3939

4040
def get_or_create(
4141
self,
@@ -52,9 +52,10 @@ def get_or_create(
5252
# TODO configure session size (and check existing_session has same size)
5353
if existing_session:
5454
self._check_expiry_date(existing_session)
55+
self._check_memory_configuration(existing_session, memory.value)
5556
session_id = existing_session.id
5657
else:
57-
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory)
58+
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory.value)
5859
session_id = create_details.id
5960

6061
wait_result = self._aura_api.wait_for_session_running(session_id, dbid)
@@ -108,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi
108109
return matched_sessions[0]
109110

110111
def _create_session(
111-
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemory
112+
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemoryValue
112113
) -> SessionDetails:
113114
db_instance = self._aura_api.list_instance(dbid)
114115
if not db_instance:
@@ -118,7 +119,7 @@ def _create_session(
118119
name=session_name,
119120
dbid=dbid,
120121
pwd=pwd,
121-
memory=memory.value,
122+
memory=memory,
122123
)
123124
return create_details
124125

@@ -139,6 +140,15 @@ def _check_expiry_date(self, session: SessionDetails) -> None:
139140
if until_expiry < timedelta(days=1):
140141
raise Warning(f"Session `{session.name}` is expiring in less than a day.")
141142

143+
def _check_memory_configuration(
144+
self, existing_session: SessionDetails, requested_memory: SessionMemoryValue
145+
) -> None:
146+
if existing_session.memory != requested_memory:
147+
raise RuntimeError(
148+
f"Session `{existing_session.name}` exists with a different memory configuration. "
149+
f"Current: {existing_session.memory}, Requested: {requested_memory}."
150+
)
151+
142152
@classmethod
143153
def _fail_ambiguous_session(cls, session_name: str, sessions: List[SessionDetails]) -> None:
144154
candidates = [i.id for i in sessions]

graphdatascience/session/session_info.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional
66

77
from graphdatascience.session.aura_api_responses import SessionDetails
8+
from graphdatascience.session.session_sizes import SessionMemoryValue
89

910

1011
@dataclass(frozen=True)
@@ -18,7 +19,7 @@ class SessionInfo:
1819
"""
1920

2021
name: str
21-
memory: str
22+
memory: SessionMemoryValue
2223

2324
@classmethod
2425
def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo:

graphdatascience/session/session_sizes.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,53 @@
1+
from dataclasses import dataclass
12
from enum import Enum
23
from typing import List
34

45

6+
@dataclass(frozen=True)
7+
class SessionMemoryValue:
8+
value: str
9+
10+
def __str__(self) -> str:
11+
return self.value
12+
13+
@staticmethod
14+
def fromApiResponse(value: str) -> "SessionMemoryValue":
15+
"""
16+
Converts the string value from an API response to a SessionMemory enumeration value.
17+
18+
Args:
19+
value: The string value from the API response.
20+
21+
Returns:
22+
The SessionMemory enumeration value.
23+
24+
"""
25+
if value == "":
26+
raise ValueError("memory configuration cannot be empty")
27+
28+
return SessionMemoryValue(value.replace("Gi", "GB"))
29+
30+
531
class SessionMemory(Enum):
632
"""
733
Enumeration representing session main memory configurations.
834
"""
935

10-
m_8GB = "8GB"
11-
m_16GB = "16GB"
12-
m_24GB = "24GB"
13-
m_32GB = "32GB"
14-
m_48GB = "48GB"
15-
m_64GB = "64GB"
16-
m_96GB = "96GB"
17-
m_128GB = "128GB"
18-
m_192GB = "192GB"
19-
m_256GB = "256GB"
20-
m_384GB = "384GB"
36+
m_4GB = SessionMemoryValue("4GB")
37+
m_8GB = SessionMemoryValue("8GB")
38+
m_16GB = SessionMemoryValue("16GB")
39+
m_24GB = SessionMemoryValue("24GB")
40+
m_32GB = SessionMemoryValue("32GB")
41+
m_48GB = SessionMemoryValue("48GB")
42+
m_64GB = SessionMemoryValue("64GB")
43+
m_96GB = SessionMemoryValue("96GB")
44+
m_128GB = SessionMemoryValue("128GB")
45+
m_192GB = SessionMemoryValue("192GB")
46+
m_256GB = SessionMemoryValue("256GB")
47+
m_384GB = SessionMemoryValue("384GB")
2148

2249
@classmethod
23-
def all_values(cls) -> List[str]:
50+
def all_values(cls) -> List[SessionMemoryValue]:
2451
"""
2552
All supported memory configurations.
2653

0 commit comments

Comments
 (0)