4242
4343# Signature bytes for each message type
4444INIT = b"\x01 " # 0000 0001 // INIT <user_agent>
45- ACK_FAILURE = b"\x0F " # 0000 1111 // ACK_FAILURE
45+ RESET = b"\x0F " # 0000 1111 // RESET
4646RUN = b"\x10 " # 0001 0000 // RUN <statement> <parameters>
4747DISCARD_ALL = b"\x2F " # 0010 1111 // DISCARD *
4848PULL_ALL = b"\x3F " # 0011 1111 // PULL *
5656
5757message_names = {
5858 INIT : "INIT" ,
59- ACK_FAILURE : "ACK_FAILURE " ,
59+ RESET : "RESET " ,
6060 RUN : "RUN" ,
6161 DISCARD_ALL : "DISCARD_ALL" ,
6262 PULL_ALL : "PULL_ALL" ,
@@ -169,14 +169,6 @@ def chunk_reader(self):
169169 data = self ._recv (chunk_size )
170170 yield data
171171
172- def close (self ):
173- """ Shut down and close the connection.
174- """
175- if __debug__ : log_info ("~~ [CLOSE]" )
176- socket = self .socket
177- socket .shutdown (SHUT_RDWR )
178- socket .close ()
179-
180172
181173class Response (object ):
182174 """ Subscriber object for a full response (zero or
@@ -200,12 +192,6 @@ def on_ignored(self, metadata=None):
200192 pass
201193
202194
203- class AckFailureResponse (Response ):
204-
205- def on_failure (self , metadata ):
206- raise ProtocolError ("Could not acknowledge failure" )
207-
208-
209195class Connection (object ):
210196 """ Server connection through which all protocol messages
211197 are sent and received. This class is designed for protocol
@@ -215,9 +201,11 @@ class Connection(object):
215201 """
216202
217203 def __init__ (self , sock , ** config ):
204+ self .defunct = False
218205 self .channel = ChunkChannel (sock )
219206 self .packer = Packer (self .channel )
220207 self .responses = deque ()
208+ self .closed = False
221209
222210 # Determine the user agent and ensure it is a Unicode value
223211 user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
@@ -235,8 +223,15 @@ def on_failure(metadata):
235223 while not response .complete :
236224 self .fetch_next ()
237225
226+ def __del__ (self ):
227+ self .close ()
228+
238229 def append (self , signature , fields = (), response = None ):
239230 """ Add a message to the outgoing queue.
231+
232+ :arg signature: the signature of the message
233+ :arg fields: the fields of the message as a tuple
234+ :arg response: a response object to handle callbacks
240235 """
241236 if __debug__ :
242237 log_info ("C: %s %s" , message_names [signature ], " " .join (map (repr , fields )))
@@ -247,42 +242,75 @@ def append(self, signature, fields=(), response=None):
247242 self .channel .flush (end_of_message = True )
248243 self .responses .append (response )
249244
245+ def reset (self ):
246+ """ Add a RESET message to the outgoing queue, send
247+ it and consume all remaining messages.
248+ """
249+ response = Response (self )
250+
251+ def on_failure (metadata ):
252+ raise ProtocolError ("Reset failed" )
253+
254+ response .on_failure = on_failure
255+
256+ self .append (RESET , response = response )
257+ self .send ()
258+ fetch_next = self .fetch_next
259+ while not response .complete :
260+ fetch_next ()
261+
250262 def send (self ):
251263 """ Send all queued messages to the server.
252264 """
265+ if self .closed :
266+ raise ProtocolError ("Cannot write to a closed connection" )
267+ if self .defunct :
268+ raise ProtocolError ("Cannot write to a defunct connection" )
253269 self .channel .send ()
254270
255271 def fetch_next (self ):
256272 """ Receive exactly one message from the server.
257273 """
274+ if self .closed :
275+ raise ProtocolError ("Cannot read from a closed connection" )
276+ if self .defunct :
277+ raise ProtocolError ("Cannot read from a defunct connection" )
258278 raw = BytesIO ()
259279 unpack = Unpacker (raw ).unpack
260- raw .writelines (self .channel .chunk_reader ())
261-
280+ try :
281+ raw .writelines (self .channel .chunk_reader ())
282+ except ProtocolError :
283+ self .defunct = True
284+ self .close ()
285+ raise
262286 # Unpack from the raw byte stream and call the relevant message handler(s)
263287 raw .seek (0 )
264288 response = self .responses [0 ]
265289 for signature , fields in unpack ():
266290 if __debug__ :
267291 log_info ("S: %s %s" , message_names [signature ], " " .join (map (repr , fields )))
292+ if signature in SUMMARY :
293+ response .complete = True
294+ self .responses .popleft ()
295+ if signature == FAILURE :
296+ self .reset ()
268297 handler_name = "on_%s" % message_names [signature ].lower ()
269298 try :
270299 handler = getattr (response , handler_name )
271300 except AttributeError :
272301 pass
273302 else :
274303 handler (* fields )
275- if signature in SUMMARY :
276- response .complete = True
277- self .responses .popleft ()
278- if signature == FAILURE :
279- self .append (ACK_FAILURE , response = AckFailureResponse (self ))
280304 raw .close ()
281305
282306 def close (self ):
283- """ Shut down and close the connection.
307+ """ Close the connection.
284308 """
285- self .channel .close ()
309+ if not self .closed :
310+ if __debug__ :
311+ log_info ("~~ [CLOSE]" )
312+ self .channel .socket .close ()
313+ self .closed = True
286314
287315
288316def connect (host , port = None , ** config ):
0 commit comments