1111from graphdatascience .session .aura_graph_data_science import AuraGraphDataScience
1212from graphdatascience .session .dbms_connection_info import DbmsConnectionInfo
1313from graphdatascience .session .session_info import SessionInfo
14- from graphdatascience .session .session_sizes import SessionMemory
14+ from graphdatascience .session .session_sizes import SessionMemory , SessionMemoryValue
1515
1616
1717class DedicatedSessions :
@@ -35,7 +35,7 @@ def estimate(
3535 ResourceWarning ,
3636 )
3737
38- return SessionMemory (estimation .recommended_size )
38+ return SessionMemory (SessionMemoryValue ( estimation .recommended_size ) )
3939
4040 def get_or_create (
4141 self ,
@@ -52,9 +52,10 @@ def get_or_create(
5252 # TODO configure session size (and check existing_session has same size)
5353 if existing_session :
5454 self ._check_expiry_date (existing_session )
55+ self ._check_memory_configuration (existing_session , memory .value )
5556 session_id = existing_session .id
5657 else :
57- create_details = self ._create_session (session_name , dbid , db_connection .uri , password , memory )
58+ create_details = self ._create_session (session_name , dbid , db_connection .uri , password , memory . value )
5859 session_id = create_details .id
5960
6061 wait_result = self ._aura_api .wait_for_session_running (session_id , dbid )
@@ -108,7 +109,7 @@ def _find_existing_session(self, session_name: str, dbid: str) -> Optional[Sessi
108109 return matched_sessions [0 ]
109110
110111 def _create_session (
111- self , session_name : str , dbid : str , dburi : str , pwd : str , memory : SessionMemory
112+ self , session_name : str , dbid : str , dburi : str , pwd : str , memory : SessionMemoryValue
112113 ) -> SessionDetails :
113114 db_instance = self ._aura_api .list_instance (dbid )
114115 if not db_instance :
@@ -118,7 +119,7 @@ def _create_session(
118119 name = session_name ,
119120 dbid = dbid ,
120121 pwd = pwd ,
121- memory = memory . value ,
122+ memory = memory ,
122123 )
123124 return create_details
124125
@@ -139,6 +140,15 @@ def _check_expiry_date(self, session: SessionDetails) -> None:
139140 if until_expiry < timedelta (days = 1 ):
140141 raise Warning (f"Session `{ session .name } ` is expiring in less than a day." )
141142
143+ def _check_memory_configuration (
144+ self , existing_session : SessionDetails , requested_memory : SessionMemoryValue
145+ ) -> None :
146+ if existing_session .memory != requested_memory :
147+ raise RuntimeError (
148+ f"Session `{ existing_session .name } ` exists with a different memory configuration. "
149+ f"Current: { existing_session .memory } , Requested: { requested_memory } ."
150+ )
151+
142152 @classmethod
143153 def _fail_ambiguous_session (cls , session_name : str , sessions : List [SessionDetails ]) -> None :
144154 candidates = [i .id for i in sessions ]
0 commit comments