Skip to content

Commit c2c4e48

Browse files
DarthMaxFlorentinD
authored andcommitted
Use a data class instead of an enum to represent session memory
* We will keep the enum to make the selection easier for users
1 parent 3938b69 commit c2c4e48

File tree

10 files changed

+118
-106
lines changed

10 files changed

+118
-106
lines changed

graphdatascience/session/aura_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import requests as req
1010
from requests import HTTPError
1111

12-
from graphdatascience.session.session_sizes import SessionMemory
1312
from graphdatascience.session.algorithm_category import AlgorithmCategory
1413
from graphdatascience.session.aura_api_responses import (
1514
EstimationDetails,
@@ -20,6 +19,7 @@
2019
TenantDetails,
2120
WaitResult,
2221
)
22+
from graphdatascience.session.session_sizes import SessionMemoryValue
2323
from graphdatascience.version import __version__
2424

2525

@@ -63,7 +63,7 @@ def extract_id(uri: str) -> str:
6363

6464
return host.split(".")[0].split("-")[0]
6565

66-
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemory) -> SessionDetails:
66+
def create_session(self, name: str, dbid: str, pwd: str, memory: SessionMemoryValue) -> SessionDetails:
6767
response = req.post(
6868
f"{self._base_uri}/v1beta5/data-science/sessions",
6969
headers=self._build_header(),
@@ -143,7 +143,7 @@ def delete_session(self, session_id: str, dbid: str) -> bool:
143143
return False
144144

145145
def create_instance(
146-
self, name: str, memory: SessionMemory, cloud_provider: str, region: str
146+
self, name: str, memory: SessionMemoryValue, cloud_provider: str, region: str
147147
) -> InstanceCreateDetails:
148148
tenant_details = self.tenant_details()
149149

graphdatascience/session/aura_api_responses.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99

1010
from pandas import Timedelta
1111

12-
from .session_sizes import SessionMemory
12+
from .session_sizes import SessionMemoryValue
1313

1414

1515
@dataclass(repr=True, frozen=True)
1616
class SessionDetails:
1717
id: str
1818
name: str
1919
instance_id: str
20-
memory: SessionMemory
20+
memory: SessionMemoryValue
2121
status: str
2222
host: str
2323
created_at: datetime
@@ -33,7 +33,7 @@ def fromJson(cls, json: Dict[str, Any]) -> SessionDetails:
3333
id=json["id"],
3434
name=json["name"],
3535
instance_id=json["instance_id"],
36-
memory=SessionMemory.fromApiResponse(json["memory"]),
36+
memory=SessionMemoryValue.fromApiResponse(json["memory"]),
3737
status=json["status"],
3838
host=json["host"],
3939
expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None,
@@ -69,7 +69,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceDetails:
6969
class InstanceSpecificDetails(InstanceDetails):
7070
status: str
7171
connection_url: str
72-
memory: SessionMemory
72+
memory: SessionMemoryValue
7373
type: str
7474
region: str
7575

@@ -82,7 +82,7 @@ def fromJson(cls, json: Dict[str, Any]) -> InstanceSpecificDetails:
8282
cloud_provider=json["cloud_provider"],
8383
status=json["status"],
8484
connection_url=json.get("connection_url", ""),
85-
memory=SessionMemory.fromApiResponse(json.get("memory", "")),
85+
memory=SessionMemoryValue.fromApiResponse(json.get("memory", "")),
8686
type=json["type"],
8787
region=json["region"],
8888
)

graphdatascience/session/aurads_sessions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1717
from graphdatascience.session.region_suggester import closest_match
1818
from graphdatascience.session.session_info import SessionInfo
19-
from graphdatascience.session.session_sizes import SessionMemory
19+
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue
2020

2121

2222
class AuraDsSessions:
@@ -41,7 +41,7 @@ def estimate(
4141
ResourceWarning,
4242
)
4343

44-
return SessionMemory(estimation.recommended_size)
44+
return SessionMemory(SessionMemoryValue(estimation.recommended_size))
4545

4646
def get_or_create(
4747
self,
@@ -54,13 +54,13 @@ def get_or_create(
5454
if existing_session:
5555
session_id = existing_session.id
5656
# 0MB is AuraAPI default value for memory if none can be retrieved
57-
if existing_session.memory != memory:
57+
if existing_session.memory.value != "0MB" and existing_session.memory != memory.value:
5858
raise ValueError(
5959
f"Session `{session_name}` already exists with memory `{existing_session.memory.value}`. "
6060
f"Requested memory `{memory.value}` does not match."
6161
)
6262
else:
63-
create_details = self._create_session(session_name, memory, db_connection)
63+
create_details = self._create_session(session_name, memory.value, db_connection)
6464
session_id = create_details.id
6565

6666
wait_result = self._aura_api.wait_for_instance_running(session_id)
@@ -118,7 +118,7 @@ def _find_existing_session(self, session_name: str) -> Optional[InstanceSpecific
118118
return self._aura_api.list_instance(matched_instances[0].id)
119119

120120
def _create_session(
121-
self, session_name: str, memory: SessionMemory, db_connection: DbmsConnectionInfo
121+
self, session_name: str, memory: SessionMemoryValue, db_connection: DbmsConnectionInfo
122122
) -> InstanceCreateDetails:
123123
db_instance_id = AuraApi.extract_id(db_connection.uri)
124124
db_instance = self._aura_api.list_instance(db_instance_id)

graphdatascience/session/dedicated_sessions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
1212
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1313
from graphdatascience.session.session_info import SessionInfo
14-
from graphdatascience.session.session_sizes import SessionMemory
14+
from graphdatascience.session.session_sizes import SessionMemory, SessionMemoryValue
1515

1616

1717
class DedicatedSessions:
@@ -35,7 +35,7 @@ def estimate(
3535
ResourceWarning,
3636
)
3737

38-
return SessionMemory(estimation.recommended_size)
38+
return SessionMemory(SessionMemoryValue(estimation.recommended_size))
3939

4040
def get_or_create(
4141
self,
@@ -52,10 +52,10 @@ def get_or_create(
5252
# TODO configure session size (and check existing_session has same size)
5353
if existing_session:
5454
self._check_expiry_date(existing_session)
55-
self._check_memory_configuration(existing_session, memory)
55+
self._check_memory_configuration(existing_session, memory.value)
5656
session_id = existing_session.id
5757
else:
58-
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory)
58+
create_details = self._create_session(session_name, dbid, db_connection.uri, password, memory.value)
5959
session_id = create_details.id
6060

6161
wait_result = self._aura_api.wait_for_session_running(session_id, dbid)
@@ -109,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi
109109
return matched_sessions[0]
110110

111111
def _create_session(
112-
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemory
112+
self, session_name: str, dbid: str, dburi: str, pwd: str, memory: SessionMemoryValue
113113
) -> SessionDetails:
114114
db_instance = self._aura_api.list_instance(dbid)
115115
if not db_instance:
@@ -140,11 +140,13 @@ def _check_expiry_date(self, session: SessionDetails) -> None:
140140
if until_expiry < timedelta(days=1):
141141
raise Warning(f"Session `{session.name}` is expiring in less than a day.")
142142

143-
def _check_memory_configuration(self, existing_session: SessionDetails, requested_memory: SessionMemory) -> None:
143+
def _check_memory_configuration(
144+
self, existing_session: SessionDetails, requested_memory: SessionMemoryValue
145+
) -> None:
144146
if existing_session.memory != requested_memory:
145147
raise RuntimeError(
146148
f"Session `{existing_session.name}` exists with a different memory configuration. "
147-
f"Current: {existing_session.memory.value}, Requested: {requested_memory.value}."
149+
f"Current: {existing_session.memory}, Requested: {requested_memory}."
148150
)
149151

150152
@classmethod

graphdatascience/session/session_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from datetime import datetime
55
from typing import Optional
66

7-
from graphdatascience.session.session_sizes import SessionMemory
87
from graphdatascience.session.aura_api_responses import SessionDetails
8+
from graphdatascience.session.session_sizes import SessionMemoryValue
99

1010

1111
@dataclass(frozen=True)
@@ -19,7 +19,7 @@ class SessionInfo:
1919
"""
2020

2121
name: str
22-
memory: SessionMemory
22+
memory: SessionMemoryValue
2323

2424
@classmethod
2525
def from_session_details(cls, details: SessionDetails) -> ExtendedSessionInfo:

graphdatascience/session/session_sizes.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,53 @@
1+
from dataclasses import dataclass
12
from enum import Enum
23
from typing import List
34

45

6+
@dataclass(frozen=True)
7+
class SessionMemoryValue:
8+
value: str
9+
10+
def __str__(self) -> str:
11+
return self.value
12+
13+
@staticmethod
14+
def fromApiResponse(value: str) -> "SessionMemoryValue":
15+
"""
16+
Converts the string value from an API response to a SessionMemory enumeration value.
17+
18+
Args:
19+
value: The string value from the API response.
20+
21+
Returns:
22+
The SessionMemory enumeration value.
23+
24+
"""
25+
if value == "":
26+
raise ValueError("memory configuration cannot be empty")
27+
28+
return SessionMemoryValue(value.replace("Gi", "GB"))
29+
30+
531
class SessionMemory(Enum):
632
"""
733
Enumeration representing session main memory configurations.
834
"""
935

10-
m_4GB = "4GB"
11-
m_8GB = "8GB"
12-
m_16GB = "16GB"
13-
m_24GB = "24GB"
14-
m_32GB = "32GB"
15-
m_48GB = "48GB"
16-
m_64GB = "64GB"
17-
m_96GB = "96GB"
18-
m_128GB = "128GB"
19-
m_192GB = "192GB"
20-
m_256GB = "256GB"
21-
m_384GB = "384GB"
36+
m_4GB = SessionMemoryValue("4GB")
37+
m_8GB = SessionMemoryValue("8GB")
38+
m_16GB = SessionMemoryValue("16GB")
39+
m_24GB = SessionMemoryValue("24GB")
40+
m_32GB = SessionMemoryValue("32GB")
41+
m_48GB = SessionMemoryValue("48GB")
42+
m_64GB = SessionMemoryValue("64GB")
43+
m_96GB = SessionMemoryValue("96GB")
44+
m_128GB = SessionMemoryValue("128GB")
45+
m_192GB = SessionMemoryValue("192GB")
46+
m_256GB = SessionMemoryValue("256GB")
47+
m_384GB = SessionMemoryValue("384GB")
2248

2349
@classmethod
24-
def all_values(cls) -> List[str]:
50+
def all_values(cls) -> List[SessionMemoryValue]:
2551
"""
2652
All supported memory configurations.
2753
@@ -30,20 +56,3 @@ def all_values(cls) -> List[str]:
3056
3157
"""
3258
return [e.value for e in cls]
33-
34-
@staticmethod
35-
def fromApiResponse(value: str) -> "SessionMemory":
36-
"""
37-
Converts the string value from an API response to a SessionMemory enumeration value.
38-
39-
Args:
40-
value: The string value from the API response.
41-
42-
Returns:
43-
The SessionMemory enumeration value.
44-
45-
"""
46-
try:
47-
return SessionMemory(value.replace("Gi", "GB"))
48-
except ValueError:
49-
raise ValueError(f"Unsupported memory configuration: {value}")

graphdatascience/tests/unit/test_aura_api.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_create_session(requests_mock: Mocker) -> None:
3838
},
3939
)
4040

41-
result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB)
41+
result = api.create_session("name-0", "dbid-1", "pwd-2", SessionMemory.m_4GB.value)
4242

4343
assert result == SessionDetails(
4444
id="id0",
@@ -47,7 +47,7 @@ def test_create_session(requests_mock: Mocker) -> None:
4747
instance_id="dbid-1",
4848
created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"),
4949
host="1.2.3.4",
50-
memory=SessionMemory.m_4GB,
50+
memory=SessionMemory.m_4GB.value,
5151
expiry_date=None,
5252
ttl=None,
5353
)
@@ -82,7 +82,7 @@ def test_list_session(requests_mock: Mocker) -> None:
8282
instance_id="dbid-1",
8383
created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"),
8484
host="1.2.3.4",
85-
memory=SessionMemory.m_4GB,
85+
memory=SessionMemory.m_4GB.value,
8686
expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"),
8787
ttl=None,
8888
)
@@ -126,7 +126,7 @@ def test_list_sessions(requests_mock: Mocker) -> None:
126126
instance_id="dbid-1",
127127
created_at=TimeParser.fromisoformat("1970-01-01T00:00:00Z"),
128128
host="1.2.3.4",
129-
memory=SessionMemory.m_4GB,
129+
memory=SessionMemory.m_4GB.value,
130130
expiry_date=TimeParser.fromisoformat("1977-01-01T00:00:00Z"),
131131
ttl=None,
132132
)
@@ -137,7 +137,7 @@ def test_list_sessions(requests_mock: Mocker) -> None:
137137
status="Creating",
138138
instance_id="dbid-3",
139139
created_at=TimeParser.fromisoformat("2012-01-01T00:00:00Z"),
140-
memory=SessionMemory.m_8GB,
140+
memory=SessionMemory.m_8GB.value,
141141
host="foo.bar",
142142
expiry_date=None,
143143
ttl=None,
@@ -270,7 +270,7 @@ def test_delete_instance(requests_mock: Mocker) -> None:
270270
cloud_provider="",
271271
status="deleting",
272272
connection_url="",
273-
memory=SessionMemory.m_4GB,
273+
memory=SessionMemory.m_4GB.value,
274274
region="",
275275
type="",
276276
)
@@ -337,7 +337,7 @@ def test_create_instance(requests_mock: Mocker) -> None:
337337
},
338338
)
339339

340-
api.create_instance("name", SessionMemory.m_16GB, "gcp", "leipzig-1")
340+
api.create_instance("name", SessionMemory.m_16GB.value, "gcp", "leipzig-1")
341341

342342
requested_data = requests_mock.request_history[-1].json()
343343
assert requested_data["name"] == "name"
@@ -446,7 +446,7 @@ def test_list_instance_missing_memory_field(requests_mock: Mocker) -> None:
446446
result = api.list_instance("id0")
447447

448448
assert result and result.id == "a10fb995"
449-
assert result.memory == SessionMemory.m_16GB
449+
assert result.memory == SessionMemory.m_16GB.value
450450

451451

452452
def test_list_missing_instance(requests_mock: Mocker) -> None:
@@ -653,7 +653,7 @@ def test_parse_session_info() -> None:
653653
assert session_info == SessionDetails(
654654
id="test_id",
655655
name="test_session",
656-
memory=SessionMemory.m_4GB,
656+
memory=SessionMemory.m_4GB.value,
657657
instance_id="test_instance",
658658
status="running",
659659
host="a.b",
@@ -678,7 +678,7 @@ def test_parse_session_info_without_optionals() -> None:
678678
assert session_info == SessionDetails(
679679
id="test_id",
680680
name="test_session",
681-
memory=SessionMemory.m_16GB,
681+
memory=SessionMemory.m_16GB.value,
682682
instance_id="test_instance",
683683
host="a.b",
684684
status="running",

0 commit comments

Comments
 (0)