1818# specific language governing permissions and limitations
1919# under the License.
2020#
21+ from abc import ABC , abstractmethod
22+ from time import sleep
2123from typing import Dict , Optional , Set , TYPE_CHECKING
2224
2325import typedb_protocol .cluster .cluster_database_pb2 as cluster_database_proto
2426
2527from typedb .api .database import ClusterDatabase
28+ from typedb .common .exception import TypeDBClientException , UNABLE_TO_CONNECT , CLUSTER_REPLICA_NOT_PRIMARY , \
29+ CLUSTER_UNABLE_TO_CONNECT
30+ from typedb .common .rpc .request_builder import cluster_database_manager_get_req
2631from typedb .core .database import _CoreDatabase
2732
2833if TYPE_CHECKING :
29- from typedb .cluster .database_manager import _ClusterDatabaseManager
34+ from typedb .cluster .client import _ClusterClient
3035
3136
3237class _ClusterDatabase (ClusterDatabase ):
3338
def __init__(self, database: str, client: "_ClusterClient"):
    """Create the cluster-wide handle for the database called ``database``.

    One ``_CoreDatabase`` is created per address reported by the client's
    cluster database manager, each built from that address's core database
    manager stub. The replica set starts empty.

    :param database: name of the database.
    :param client: cluster client that owns the per-server database managers.
    """
    self._name = database
    self._client = client
    self._replicas: Set["_ClusterDatabase.Replica"] = set()
    self._databases: Dict[str, _CoreDatabase] = {}
    database_managers = client.databases()
    for server_address in database_managers.database_mgrs():
        server_stub = database_managers.database_mgrs()[server_address].stub()
        self._databases[server_address] = _CoreDatabase(server_stub, name=database)
4248
4349 @staticmethod
44- def of (proto_db : cluster_database_proto .ClusterDatabase , cluster_database_mgr : "_ClusterDatabaseManager " ) -> "_ClusterDatabase" :
50+ def of (proto_db : cluster_database_proto .ClusterDatabase , client : "_ClusterClient " ) -> "_ClusterDatabase" :
4551 assert proto_db .replicas
4652 database : str = proto_db .name
47- database_cluster_rpc = _ClusterDatabase (database , cluster_database_mgr )
53+ database_cluster_rpc = _ClusterDatabase (database , client )
4854 for proto_replica in proto_db .replicas :
4955 database_cluster_rpc .replicas ().add (_ClusterDatabase .Replica .of (proto_replica , database_cluster_rpc ))
5056 print ("Discovered database cluster: %s" % database_cluster_rpc )
@@ -57,9 +63,8 @@ def schema(self) -> str:
5763 return next (iter (self ._databases .values ())).schema ()
5864
def delete(self) -> None:
    """Delete this database by running a delete task against its primary replica."""
    _DeleteDatabaseFailsafeTask(self._client, self._name, self._databases).run_primary_replica()
6368
def replicas(self):
    """Return the set of replicas held by this cluster database.

    The set itself is returned (not a copy), so callers may add to it.
    """
    return self._replicas
@@ -142,3 +147,99 @@ def __hash__(self):
142147
def __str__(self):
    """Render this object as ``<address>/<database>``."""
    address_and_database = (self._address, self._database)
    return "%s/%s" % address_and_database
150+
151+
152+ # This class has to live here because of circular class creation between ClusterDatabase and FailsafeTask
class _FailsafeTask(ABC):
    """Base class for an operation executed against a replica of a cluster database.

    Subclasses implement :meth:`run`. The task is started with either
    :meth:`run_primary_replica` or :meth:`run_any_replica`; both retry on
    another replica when the targeted one reports a connection failure
    (and, for the primary variant, when it reports it is not the primary).
    """

    # Number of retries allowed after the first attempt on the primary replica.
    PRIMARY_REPLICA_TASK_MAX_RETRIES = 10
    # Number of times replica info is re-fetched while waiting for a primary.
    FETCH_REPLICAS_MAX_RETRIES = 10
    # Pause between attempts while waiting for a primary replica to be selected.
    WAIT_FOR_PRIMARY_REPLICA_SELECTION_SECONDS: float = 2

    def __init__(self, client: "_ClusterClient", database: str):
        """
        :param client: cluster client used to look up and refresh replica info.
        :param database: name of the database the task operates on.
        """
        self.client = client
        self.database = database

    @abstractmethod
    def run(self, replica: "_ClusterDatabase.Replica"):
        """Execute the task against ``replica`` (first attempt)."""
        pass

    def rerun(self, replica: "_ClusterDatabase.Replica"):
        """Execute the task again after a failed attempt; defaults to :meth:`run`."""
        return self.run(replica)

    def run_primary_replica(self):
        """Run the task on the primary replica, re-seeking the primary on failure.

        :return: whatever :meth:`run` / :meth:`rerun` returns.
        :raises TypeDBClientException: the original error when it is neither
            "replica not primary" nor "unable to connect", or a cluster
            unavailable error once the retry budget is exhausted.
        """
        cached_databases = self.client.database_by_name()
        replica = None
        if self.database in cached_databases:
            replica = cached_databases[self.database].primary_replica()
        if not replica:
            # Nothing cached (or no known primary): ask the cluster.
            replica = self._seek_primary_replica()

        retries = 0
        while True:
            try:
                return self.run(replica) if retries == 0 else self.rerun(replica)
            except TypeDBClientException as e:
                if e.error_message not in (CLUSTER_REPLICA_NOT_PRIMARY, UNABLE_TO_CONNECT):
                    raise
                print("Unable to open a session or transaction, retrying in 2s... %s" % str(e))
                sleep(self.WAIT_FOR_PRIMARY_REPLICA_SELECTION_SECONDS)
                replica = self._seek_primary_replica()
            retries += 1
            if retries > self.PRIMARY_REPLICA_TASK_MAX_RETRIES:
                raise self._cluster_not_available_exception()

    def run_any_replica(self):
        """Run the task on the preferred replica, falling back to the others in turn.

        :return: whatever :meth:`run` / :meth:`rerun` returns.
        :raises TypeDBClientException: the original error when it is not
            "unable to connect", or a cluster unavailable error when every
            replica failed to connect.
        """
        cached_databases = self.client.database_by_name()
        if self.database in cached_databases:
            cluster_database = cached_databases[self.database]
        else:
            cluster_database = self._fetch_database_replicas()

        # Preferred replica first, then every replica not flagged as preferred.
        candidates = [cluster_database.preferred_replica()]
        candidates += [replica for replica in cluster_database.replicas() if not replica.is_preferred()]
        for attempt, replica in enumerate(candidates):
            try:
                return self.run(replica) if attempt == 0 else self.rerun(replica)
            except TypeDBClientException as e:
                # Compare by equality, consistent with run_primary_replica.
                if e.error_message != UNABLE_TO_CONNECT:
                    raise
                print("Unable to open a session or transaction to %s. Attempting next replica. %s" % (str(replica.replica_id()), str(e)))
        raise self._cluster_not_available_exception()

    def _seek_primary_replica(self) -> "_ClusterDatabase.Replica":
        """Re-fetch replica info until a primary replica is reported.

        :raises TypeDBClientException: cluster unavailable error when no
            primary shows up within ``FETCH_REPLICAS_MAX_RETRIES`` fetches.
        """
        for _ in range(self.FETCH_REPLICAS_MAX_RETRIES):
            primary = self._fetch_database_replicas().primary_replica()
            if primary:
                return primary
            sleep(self.WAIT_FOR_PRIMARY_REPLICA_SELECTION_SECONDS)
        raise self._cluster_not_available_exception()

    def _fetch_database_replicas(self) -> "_ClusterDatabase":
        """Fetch replica info from the first reachable cluster member and cache it.

        The resulting ``_ClusterDatabase`` is stored in the client's
        database-by-name mapping before being returned.

        :raises TypeDBClientException: the original error when it is not
            "unable to connect", or a cluster unavailable error when no
            member could be reached.
        """
        for server_address in self.client.cluster_members():
            try:
                print("Fetching replica info from %s" % server_address)
                res = self.client.stub(server_address).databases_get(cluster_database_manager_get_req(self.database))
                cluster_database = _ClusterDatabase.of(res.database, self.client)
                self.client.database_by_name()[self.database] = cluster_database
                return cluster_database
            except TypeDBClientException as e:
                if e.error_message != UNABLE_TO_CONNECT:
                    raise
                print("Unable to fetch replica info for database '%s' from %s. Attempting next address. %s" % (self.database, server_address, str(e)))
        raise self._cluster_not_available_exception()

    def _cluster_not_available_exception(self) -> "TypeDBClientException":
        """Build the error reporting that none of the cluster members could be used."""
        return TypeDBClientException.of(CLUSTER_UNABLE_TO_CONNECT, str([str(addr) for addr in self.client.cluster_members()]))
236+
237+
class _DeleteDatabaseFailsafeTask(_FailsafeTask):
    """Failsafe task that deletes a database on the replica it is run against."""

    def __init__(self, client: "_ClusterClient", database: str, databases: Dict[str, _CoreDatabase]):
        """
        :param client: cluster client, passed to the base task.
        :param database: name of the database to delete.
        :param databases: per-address core database handles, keyed by the
            value returned from ``replica.address()``.
        """
        super().__init__(client, database)
        self.databases = databases

    def run(self, replica: "_ClusterDatabase.Replica"):
        """Delete the core database that lives at ``replica``'s address.

        :raises KeyError: when no core database is registered for the
            replica's address (previously surfaced as an ``AttributeError``
            on ``None``).
        """
        # Index rather than .get(): an unknown address should be reported as
        # such instead of failing later on a None handle.
        self.databases[replica.address()].delete()
0 commit comments