Skip to content

Commit 9f0e418

Browse files
author
Arthur Gautier
committed
Merge branch 'gtid-intervals' into 'master'
2 parents 7591950 + 12fa003 commit 9f0e418

File tree

2 files changed

+179
-31
lines changed

2 files changed

+179
-31
lines changed

pymysqlreplication/gtid.py

Lines changed: 175 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,60 @@
55
import binascii
66
from io import BytesIO
77

8+
def overlap(i1, i2):
9+
return i1[0] < i2[1] and i1[1] > i2[0]
10+
11+
def contains(i1, i2):
12+
return i2[0] >= i1[0] and i2[1] <= i1[1]
813

914
class Gtid(object):
15+
"""A mysql GTID is composed of a server-id and a set of right-open
16+
intervals [a,b), and represents all transactions x that happened on
17+
server SID such that
18+
19+
a <= x < b
20+
21+
The human representation of it, though, is either a
22+
single transaction number A=a (when only one transaction is covered,
23+
i.e. b = a+1)
24+
25+
SID:A
26+
27+
Or a closed interval [A,B] for at least two transactions (note, in that
28+
case, that b=B+1)
29+
30+
SID:A-B
31+
32+
We can also have a mix of ranges for a given SID:
33+
SID:1-2:4:6-74
34+
35+
For convenience, a Gtid accepts adding Gtids to it and will merge
36+
the existing interval representation. Adding TXN 3 to the human
37+
representation above would produce:
38+
39+
SID:1-4:6-74
40+
41+
and adding 5 to this new result:
42+
43+
SID:1-74
44+
45+
Adding an already present transaction number (one that overlaps) will
46+
raise an exception.
47+
48+
Adding a Gtid with a different SID will raise an exception.
49+
"""
1050
@staticmethod
1151
def parse_interval(interval):
52+
"""
53+
We parse a human-generated string here. So our end value b
54+
is incremented to conform to the internal representation format.
55+
"""
1256
m = re.search('^([0-9]+)(?:-([0-9]+))?$', interval)
1357
if not m:
1458
raise ValueError('GTID format is incorrect: %r' % (interval, ))
15-
if not m.group(2):
16-
return (int(m.group(1)))
17-
else:
18-
return (int(m.group(1)), int(m.group(2)))
59+
a = int(m.group(1))
60+
b = int(m.group(2) or a)
61+
return (a, b+1)
1962

2063
@staticmethod
2164
def parse(gtid):
@@ -32,16 +75,111 @@ def parse(gtid):
3275

3376
return (sid, intervals_parsed)
3477

35-
def __init__(self, gtid):
36-
self.sid = None
78+
def __add_interval(self, itvl):
79+
"""
80+
Use the internal representation format and add it
81+
to our intervals, merging if required.
82+
"""
83+
new = []
84+
85+
if itvl[0] > itvl[1]:
86+
raise Exception('Malformed interval %s' % (itvl,))
87+
88+
if any(overlap(x, itvl) for x in self.intervals):
89+
raise Exception('Overlapping interval %s' % (itvl,))
90+
91+
## Merge: arrange interval to fit existing set
92+
for existing in sorted(self.intervals):
93+
if itvl[0] == existing[1]:
94+
itvl = (existing[0], itvl[1])
95+
continue
96+
97+
if itvl[1] == existing[0]:
98+
itvl = (itvl[0], existing[1])
99+
continue
100+
101+
new.append(existing)
102+
103+
self.intervals = sorted(new + [itvl])
104+
105+
def __sub_interval(self, itvl):
106+
"""Using the internal representation, remove an interval"""
107+
new = []
108+
109+
if itvl[0] > itvl[1]:
110+
raise Exception('Malformed interval %s' % (itvl,))
111+
112+
if not any(overlap(x, itvl) for x in self.intervals):
113+
# No raise
114+
return
115+
116+
## Merge: arrange existing set around interval
117+
for existing in sorted(self.intervals):
118+
if overlap(existing, itvl):
119+
if existing[0] < itvl[0]:
120+
new.append((existing[0], itvl[0]))
121+
if existing[1] > itvl[1]:
122+
new.append((itvl[1], existing[1]))
123+
else:
124+
new.append(existing)
125+
126+
self.intervals = new
127+
128+
def __contains__(self, other):
129+
if other.sid != self.sid:
130+
return False
131+
132+
return all(any(contains(me, them) for me in self.intervals)
133+
for them in other.intervals)
134+
135+
def __init__(self, gtid, sid=None, intervals=[]):
136+
if sid:
137+
intervals = intervals
138+
else:
139+
sid, intervals = Gtid.parse(gtid)
140+
141+
self.sid = sid
37142
self.intervals = []
143+
for itvl in intervals:
144+
self.__add_interval(itvl)
145+
146+
def __add__(self, other):
147+
"""Include the transactions of this gtid. Raise if the
148+
attempted merge has different SID"""
149+
if self.sid != other.sid:
150+
raise Exception('Attempt to merge different SID'
151+
'%s != %s' % (self.sid, other.sid))
152+
153+
result = Gtid(str(self))
154+
155+
for itvl in other.intervals:
156+
result.__add_interval(itvl)
38157

39-
self.sid, self.intervals = Gtid.parse(gtid)
158+
return result
159+
160+
def __sub__(self, other):
161+
"""Remove intervals. Do not raise, if different SID simply
162+
ignore"""
163+
result = Gtid(str(self))
164+
if self.sid != other.sid:
165+
return result
166+
167+
for itvl in other.intervals:
168+
result.__sub_interval(itvl)
169+
170+
return result
171+
172+
def __cmp__(self, other):
173+
if other.sid != self.sid:
174+
return cmp(self.sid, other.sid)
175+
return cmp(self.intervals, other.intervals)
40176

41177
def __str__(self):
178+
"""We represent the human value here - a single number
179+
for one transaction, or a closed interval (decrementing b)"""
42180
return '%s:%s' % (self.sid,
43-
':'.join(('%d-%s' % x) if isinstance(x, tuple)
44-
else str(x)
181+
':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
182+
else str(x[0])
45183
for x in self.intervals))
46184

47185
def __repr__(self):
@@ -63,21 +201,10 @@ def encode(self):
63201
buffer += struct.pack('<Q', len(self.intervals))
64202

65203
for interval in self.intervals:
66-
if isinstance(interval, tuple):
67-
# Do we have both a start and a stop position
68-
# Start position
69-
buffer += struct.pack('<Q', interval[0])
70-
# Stop position
71-
buffer += struct.pack('<Q', interval[1])
72-
else:
73-
# If we only have a start position
74-
# Like c63b8356-d74e-4870-8150-70eca127beb1:1,
75-
# the stop position is start + 1
76-
77-
# Start position
78-
buffer += struct.pack('<Q', interval)
79-
# Stop position
80-
buffer += struct.pack('<Q', interval + 1)
204+
# Start position
205+
buffer += struct.pack('<Q', interval[0])
206+
# Stop position
207+
buffer += struct.pack('<Q', interval[1])
81208

82209
return buffer
83210

@@ -100,10 +227,7 @@ def decode(cls, payload):
100227
intervals = []
101228
for i in range(0, n_intervals):
102229
start, end = struct.unpack('<QQ', payload.read(16))
103-
if end == start + 1:
104-
intervals.append(start)
105-
else:
106-
intervals.append((start, end))
230+
intervals.append((start, end-1))
107231

108232
return cls('%s:%s' % (sid.decode('ascii'), ':'.join([
109233
'%d-%d' % x
@@ -126,6 +250,29 @@ def _to_gtid(element):
126250
else:
127251
self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]
128252

253+
def merge_gtid(self, gtid):
254+
new_gtids = []
255+
for existing in self.gtids:
256+
if existing.sid == gtid.sid:
257+
new_gtids.append(existing + gtid)
258+
else:
259+
new_gtids.append(existing)
260+
if gtid.sid not in (x.sid for x in new_gtids):
261+
new_gtids.append(gtid)
262+
self.gtids = new_gtids
263+
264+
def __contains__(self, other):
265+
if isinstance(other, Gtid):
266+
return any(other in x for x in self.gtids)
267+
raise NotImplementedError
268+
269+
def __add__(self, other):
270+
if isinstance(other, Gtid):
271+
new = GtidSet(self.gtids)
272+
new.merge_gtid(other)
273+
return new
274+
raise NotImplementedError
275+
129276
def __str__(self):
130277
return ','.join(str(x) for x in self.gtids)
131278

pymysqlreplication/tests/test_basic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -810,11 +810,12 @@ def test_position_gtid(self):
810810
query = "COMMIT;"
811811
self.execute(query)
812812

813-
query = "CREATE TABLE test2 (id INT NOT NULL, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))"
814-
self.execute(query)
815813
query = "SELECT @@global.gtid_executed;"
816814
gtid = self.execute(query).fetchone()[0]
817815

816+
query = "CREATE TABLE test2 (id INT NOT NULL, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))"
817+
self.execute(query)
818+
818819
self.stream.close()
819820
self.stream = BinLogStreamReader(
820821
self.database, server_id=1024, blocking=True, auto_position=gtid,
@@ -845,7 +846,7 @@ def test_gtidset_representation_newline(self):
845846
myset = GtidSet(mysql_repr)
846847
self.assertEqual(str(myset), set_repr)
847848

848-
def test_gtidset_representation(self):
849+
def test_gtidset_representation_payload(self):
849850
set_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,' \
850851
'4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20'
851852

0 commit comments

Comments
 (0)