55import binascii
66from copy import deepcopy
77from io import BytesIO
8+ from typing import List , Optional , Tuple , Union , Set
89
9- def overlap (i1 , i2 ):
10+
11+ def overlap (i1 : Tuple [int , int ], i2 : Tuple [int , int ]) -> bool :
1012 return i1 [0 ] < i2 [1 ] and i1 [1 ] > i2 [0 ]
1113
12- def contains (i1 , i2 ) :
14+ def contains (i1 : Tuple [ int , int ], i2 : Tuple [ int , int ]) -> bool :
1315 return i2 [0 ] >= i1 [0 ] and i2 [1 ] <= i1 [1 ]
1416
1517class Gtid (object ):
16- """A mysql GTID is composed of a server-id and a set of right-open
18+ """
19+ A mysql GTID is composed of a server-id and a set of right-open
1720 intervals [a,b), and represent all transactions x that happened on
1821 server SID such as
1922
@@ -49,7 +52,7 @@ class Gtid(object):
4952 Exception: Adding a Gtid with a different SID.
5053 """
5154 @staticmethod
52- def parse_interval (interval ) :
55+ def parse_interval (interval : str ) -> Tuple [ int , int ] :
5356 """
5457 We parse a human-generated string here. So our end value b
5558 is incremented to conform to the internal representation format.
@@ -65,8 +68,9 @@ def parse_interval(interval):
6568 return (a , b + 1 )
6669
6770 @staticmethod
68- def parse (gtid ):
69- """Parse a GTID from mysql textual format.
71+ def parse (gtid : str ) -> Tuple [str , List [Tuple [int , int ]]]:
72+ """
73+ Parse a GTID from mysql textual format.
7074
7175 Raises:
7276 - ValueError: if GTID format is incorrect.
@@ -84,15 +88,15 @@ def parse(gtid):
8488
8589 return (sid , intervals_parsed )
8690
87- def __add_interval (self , itvl ) :
91+ def __add_interval (self , itvl : Tuple [ int , int ]) -> None :
8892 """
8993 Use the internal representation format and add it
9094 to our intervals, merging if required.
9195
9296 Raises:
9397 Exception: if Malformated interval or Overlapping interval
9498 """
95- new = []
99+ new : List [ Tuple [ int , int ]] = []
96100
97101 if itvl [0 ] > itvl [1 ]:
98102 raise Exception ('Malformed interval %s' % (itvl ,))
@@ -114,11 +118,13 @@ def __add_interval(self, itvl):
114118
115119 self .intervals = sorted (new + [itvl ])
116120
117- def __sub_interval (self , itvl ):
118- """Using the internal representation, remove an interval
121+ def __sub_interval (self , itvl : Tuple [int , int ]) -> None :
122+ """
123+ Using the internal representation, remove an interval
119124
120- Raises: Exception if itvl malformated"""
121- new = []
125+ Raises: Exception if itvl malformated
126+ """
127+ new : List [Tuple [int , int ]] = []
122128
123129 if itvl [0 ] > itvl [1 ]:
124130 raise Exception ('Malformed interval %s' % (itvl ,))
@@ -139,8 +145,9 @@ def __sub_interval(self, itvl):
139145
140146 self .intervals = new
141147
142- def __contains__ (self , other ):
143- """Test if other is contained within self.
148+ def __contains__ (self , other : 'Gtid' ) -> bool :
149+ """
150+ Test if other is contained within self.
144151 First we compare sid they must be equals.
145152
146153 Then we search if intervals from other are contained within
@@ -152,22 +159,22 @@ def __contains__(self, other):
152159 return all (any (contains (me , them ) for me in self .intervals )
153160 for them in other .intervals )
154161
155- def __init__ (self , gtid , sid = None , intervals = []):
156- if sid :
157- intervals = intervals
158- else :
162+ def __init__ (self , gtid : str , sid : Optional [str ] = None , intervals : Optional [List [Tuple [int , int ]]] = None ) -> None :
163+ if sid is None :
159164 sid , intervals = Gtid .parse (gtid )
160165
161166 self .sid = sid
162167 self .intervals = []
163168 for itvl in intervals :
164169 self .__add_interval (itvl )
165170
166- def __add__ (self , other ):
167- """Include the transactions of this gtid.
171+ def __add__ (self , other : 'Gtid' ) -> 'Gtid' :
172+ """
173+ Include the transactions of this gtid.
168174
169175 Raises:
170- Exception: if the attempted merge has different SID"""
176+ Exception: if the attempted merge has different SID
177+ """
171178 if self .sid != other .sid :
172179 raise Exception ('Attempt to merge different SID'
173180 '%s != %s' % (self .sid , other .sid ))
@@ -179,9 +186,10 @@ def __add__(self, other):
179186
180187 return result
181188
182- def __sub__ (self , other ):
183- """Remove intervals. Do not raise, if different SID simply
184- ignore"""
189+ def __sub__ (self , other : 'Gtid' ) -> 'Gtid' :
190+ """
191+ Remove intervals. Do not raise, if different SID simply ignore
192+ """
185193 result = deepcopy (self )
186194 if self .sid != other .sid :
187195 return result
@@ -191,27 +199,30 @@ def __sub__(self, other):
191199
192200 return result
193201
194- def __str__ (self ):
195- """We represent the human value here - a single number
196- for one transaction, or a closed interval (decrementing b)"""
202+ def __str__ (self ) -> str :
203+ """
204+ We represent the human value here - a single number
205+ for one transaction, or a closed interval (decrementing b)
206+ """
197207 return '%s:%s' % (self .sid ,
198208 ':' .join (('%d-%d' % (x [0 ], x [1 ]- 1 )) if x [0 ] + 1 != x [1 ]
199209 else str (x [0 ])
200210 for x in self .intervals ))
201211
202- def __repr__ (self ):
212+ def __repr__ (self ) -> str :
203213 return '<Gtid "%s">' % self
204214
205215 @property
206- def encoded_length (self ):
216+ def encoded_length (self ) -> int :
207217 return (16 + # sid
208218 8 + # n_intervals
209219 2 * # stop/start
210220 8 * # stop/start mark encoded as int64
211221 len (self .intervals ))
212222
213- def encode (self ):
214- """Encode a Gtid in binary
223+ def encode (self ) -> bytes :
224+ """
225+ Encode a Gtid in binary
215226 Bytes are in **little endian**.
216227
217228 Format:
@@ -251,8 +262,9 @@ def encode(self):
251262 return buffer
252263
253264 @classmethod
254- def decode (cls , payload ):
255- """Decode from binary a Gtid
265+ def decode (cls , payload : BytesIO ) -> 'Gtid' :
266+ """
267+ Decode from binary a Gtid
256268
257269 :param BytesIO payload to decode
258270 """
@@ -281,35 +293,35 @@ def decode(cls, payload):
281293 else '%d' % x
282294 for x in intervals ])))
283295
284- def __eq__ (self , other ) :
296+ def __eq__ (self , other : 'Gtid' ) -> bool :
285297 if other .sid != self .sid :
286298 return False
287299 return self .intervals == other .intervals
288300
289- def __lt__ (self , other ) :
301+ def __lt__ (self , other : 'Gtid' ) -> bool :
290302 if other .sid != self .sid :
291303 return self .sid < other .sid
292304 return self .intervals < other .intervals
293305
294- def __le__ (self , other ) :
306+ def __le__ (self , other : 'Gtid' ) -> bool :
295307 if other .sid != self .sid :
296308 return self .sid <= other .sid
297309 return self .intervals <= other .intervals
298310
299- def __gt__ (self , other ) :
311+ def __gt__ (self , other : 'Gtid' ) -> bool :
300312 if other .sid != self .sid :
301313 return self .sid > other .sid
302314 return self .intervals > other .intervals
303315
304- def __ge__ (self , other ) :
316+ def __ge__ (self , other : 'Gtid' ) -> bool :
305317 if other .sid != self .sid :
306318 return self .sid >= other .sid
307319 return self .intervals >= other .intervals
308320
309321
310322class GtidSet (object ):
311323 """Represents a set of Gtid"""
312- def __init__ (self , gtid_set ) :
324+ def __init__ (self , gtid_set : Optional [ Union [ None , str , Set [ Gtid ], List [ Gtid ], Gtid ]] = None ) -> None :
313325 """
314326 Construct a GtidSet initial state depends of the nature of `gtid_set` param.
315327
@@ -325,21 +337,21 @@ def __init__(self, gtid_set):
325337 - Exception: if Gtid interval are either malformated or overlapping
326338 """
327339
328- def _to_gtid (element ) :
340+ def _to_gtid (element : str ) -> Gtid :
329341 if isinstance (element , Gtid ):
330342 return element
331343 return Gtid (element .strip (' \n ' ))
332344
333345 if not gtid_set :
334- self .gtids = []
346+ self .gtids : List [ Gtid ] = []
335347 elif isinstance (gtid_set , (list , set )):
336- self .gtids = [_to_gtid (x ) for x in gtid_set ]
348+ self .gtids : List [ Gtid ] = [_to_gtid (x ) for x in gtid_set ]
337349 else :
338- self .gtids = [Gtid (x .strip (' \n ' )) for x in gtid_set .split (',' )]
350+ self .gtids : List [ Gtid ] = [Gtid (x .strip (' \n ' )) for x in gtid_set .split (',' )]
339351
340- def merge_gtid (self , gtid ) :
352+ def merge_gtid (self , gtid : Gtid ) -> None :
341353 """Insert a Gtid in current GtidSet."""
342- new_gtids = []
354+ new_gtids : List [ Gtid ] = []
343355 for existing in self .gtids :
344356 if existing .sid == gtid .sid :
345357 new_gtids .append (existing + gtid )
@@ -349,7 +361,7 @@ def merge_gtid(self, gtid):
349361 new_gtids .append (gtid )
350362 self .gtids = new_gtids
351363
352- def __contains__ (self , other ) :
364+ def __contains__ (self , other : Union [ 'GtidSet' , Gtid ]) -> bool :
353365 """
354366 Test if self contains other, could be a GtidSet or a Gtid.
355367
@@ -363,7 +375,7 @@ def __contains__(self, other):
363375 return any (other in x for x in self .gtids )
364376 raise NotImplementedError
365377
366- def __add__ (self , other ) :
378+ def __add__ (self , other : Union [ 'GtidSet' , Gtid ]) -> 'GtidSet' :
367379 """
368380 Merge current instance with an other GtidSet or with a Gtid alone.
369381
@@ -384,22 +396,23 @@ def __add__(self, other):
384396
385397 raise NotImplementedError
386398
387- def __str__ (self ):
399+ def __str__ (self ) -> str :
388400 """
389401 Returns a comma separated string of gtids.
390402 """
391403 return ',' .join (str (x ) for x in self .gtids )
392404
393- def __repr__ (self ):
405+ def __repr__ (self ) -> str :
394406 return '<GtidSet %r>' % self .gtids
395407
396408 @property
397- def encoded_length (self ):
409+ def encoded_length (self ) -> int :
398410 return (8 + # n_sids
399411 sum (x .encoded_length for x in self .gtids ))
400412
401- def encoded (self ):
402- """Encode a GtidSet in binary
413+ def encoded (self ) -> bytes :
414+ """
415+ Encode a GtidSet in binary
403416 Bytes are in **little endian**.
404417
405418 - `n_sid`: u64 is the number of Gtid to read
@@ -421,8 +434,9 @@ def encoded(self):
421434 encode = encoded
422435
423436 @classmethod
424- def decode (cls , payload ):
425- """Decode a GtidSet from binary.
437+ def decode (cls , payload : BytesIO ) -> 'GtidSet' :
438+ """
439+ Decode a GtidSet from binary.
426440
427441 :param BytesIO payload to decode
428442 """
@@ -432,5 +446,5 @@ def decode(cls, payload):
432446
433447 return cls ([Gtid .decode (payload ) for _ in range (0 , n_sid )])
434448
435- def __eq__ (self , other ) :
449+ def __eq__ (self , other : 'GtidSet' ) -> bool :
436450 return self .gtids == other .gtids
0 commit comments