@@ -287,21 +287,26 @@ def test_simple_producer(self):
287287 producer = SimpleProducer (self .client )
288288 resp = producer .send_messages (self .topic , "one" , "two" )
289289
290- # Will go to partition 0
290+ partition_for_first_batch = resp [0 ].partition
291+
291292 self .assertEquals (len (resp ), 1 )
292293 self .assertEquals (resp [0 ].error , 0 )
293294 self .assertEquals (resp [0 ].offset , 0 ) # offset of first msg
294295
295- # Will go to partition 1
296+ # ensure this partition is different from the first partition
296297 resp = producer .send_messages (self .topic , "three" )
298+ partition_for_second_batch = resp [0 ].partition
299+ self .assertNotEquals (partition_for_first_batch , partition_for_second_batch )
300+
297301 self .assertEquals (len (resp ), 1 )
298302 self .assertEquals (resp [0 ].error , 0 )
299303 self .assertEquals (resp [0 ].offset , 0 ) # offset of first msg
300304
301- fetch1 = FetchRequest (self .topic , 0 , 0 , 1024 )
302- fetch2 = FetchRequest (self .topic , 1 , 0 , 1024 )
303- fetch_resp1 , fetch_resp2 = self .client .send_fetch_request ([fetch1 ,
304- fetch2 ])
305+ fetch_requests = (
306+ FetchRequest (self .topic , partition_for_first_batch , 0 , 1024 ),
307+ FetchRequest (self .topic , partition_for_second_batch , 0 , 1024 ),
308+ )
309+ fetch_resp1 , fetch_resp2 = self .client .send_fetch_request (fetch_requests )
305310 self .assertEquals (fetch_resp1 .error , 0 )
306311 self .assertEquals (fetch_resp1 .highwaterMark , 2 )
307312 messages = list (fetch_resp1 .messages )
@@ -314,11 +319,12 @@ def test_simple_producer(self):
314319 self .assertEquals (len (messages ), 1 )
315320 self .assertEquals (messages [0 ].message .value , "three" )
316321
317- # Will go to partition 0
322+ # Will go to same partition as first batch
318323 resp = producer .send_messages (self .topic , "four" , "five" )
319324 self .assertEquals (len (resp ), 1 )
320325 self .assertEquals (resp [0 ].error , 0 )
321326 self .assertEquals (resp [0 ].offset , 2 ) # offset of first msg
327+ self .assertEquals (resp [0 ].partition , partition_for_first_batch )
322328
323329 producer .stop ()
324330
@@ -396,14 +402,25 @@ def test_acks_none(self):
396402 resp = producer .send_messages (self .topic , "one" )
397403 self .assertEquals (len (resp ), 0 )
398404
399- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
400- fetch_resp = self .client .send_fetch_request ([fetch ])
405+ # fetch from both partitions
406+ fetch_requests = (
407+ FetchRequest (self .topic , 0 , 0 , 1024 ),
408+ FetchRequest (self .topic , 1 , 0 , 1024 ),
409+ )
410+ fetch_resps = self .client .send_fetch_request (fetch_requests )
401411
402- self .assertEquals (fetch_resp [0 ].error , 0 )
403- self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
404- self .assertEquals (fetch_resp [0 ].partition , 0 )
412+ # determine which partition was selected (due to random round-robin)
413+ published_to_resp = max (fetch_resps , key = lambda x : x .highwaterMark )
414+ not_published_to_resp = min (fetch_resps , key = lambda x : x .highwaterMark )
415+ self .assertNotEquals (published_to_resp .partition , not_published_to_resp .partition )
405416
406- messages = list (fetch_resp [0 ].messages )
417+ self .assertEquals (published_to_resp .error , 0 )
418+ self .assertEquals (published_to_resp .highwaterMark , 1 )
419+
420+ self .assertEquals (not_published_to_resp .error , 0 )
421+ self .assertEquals (not_published_to_resp .highwaterMark , 0 )
422+
423+ messages = list (published_to_resp .messages )
407424 self .assertEquals (len (messages ), 1 )
408425 self .assertEquals (messages [0 ].message .value , "one" )
409426
@@ -415,12 +432,14 @@ def test_acks_local_write(self):
415432 resp = producer .send_messages (self .topic , "one" )
416433 self .assertEquals (len (resp ), 1 )
417434
418- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
435+ partition = resp [0 ].partition
436+
437+ fetch = FetchRequest (self .topic , partition , 0 , 1024 )
419438 fetch_resp = self .client .send_fetch_request ([fetch ])
420439
421440 self .assertEquals (fetch_resp [0 ].error , 0 )
422441 self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
423- self .assertEquals (fetch_resp [0 ].partition , 0 )
442+ self .assertEquals (fetch_resp [0 ].partition , partition )
424443
425444 messages = list (fetch_resp [0 ].messages )
426445 self .assertEquals (len (messages ), 1 )
@@ -435,12 +454,14 @@ def test_acks_cluster_commit(self):
435454 resp = producer .send_messages (self .topic , "one" )
436455 self .assertEquals (len (resp ), 1 )
437456
438- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
457+ partition = resp [0 ].partition
458+
459+ fetch = FetchRequest (self .topic , partition , 0 , 1024 )
439460 fetch_resp = self .client .send_fetch_request ([fetch ])
440461
441462 self .assertEquals (fetch_resp [0 ].error , 0 )
442463 self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
443- self .assertEquals (fetch_resp [0 ].partition , 0 )
464+ self .assertEquals (fetch_resp [0 ].partition , partition )
444465
445466 messages = list (fetch_resp [0 ].messages )
446467 self .assertEquals (len (messages ), 1 )
@@ -456,17 +477,31 @@ def test_async_simple_producer(self):
456477 # Give it some time
457478 time .sleep (2 )
458479
459- fetch = FetchRequest (self .topic , 0 , 0 , 1024 )
460- fetch_resp = self .client .send_fetch_request ([fetch ])
480+ # fetch from both partitions
481+ fetch_requests = (
482+ FetchRequest (self .topic , 0 , 0 , 1024 ),
483+ FetchRequest (self .topic , 1 , 0 , 1024 ),
484+ )
485+ fetch_resps = self .client .send_fetch_request (fetch_requests )
461486
462- self .assertEquals (fetch_resp [0 ].error , 0 )
463- self .assertEquals (fetch_resp [0 ].highwaterMark , 1 )
464- self .assertEquals (fetch_resp [0 ].partition , 0 )
487+ # determine which partition was selected (due to random round-robin)
488+ published_to_resp = max (fetch_resps , key = lambda x : x .highwaterMark )
489+ not_published_to_resp = min (fetch_resps , key = lambda x : x .highwaterMark )
490+ self .assertNotEquals (published_to_resp .partition , not_published_to_resp .partition )
465491
466- messages = list (fetch_resp [0 ].messages )
492+ self .assertEquals (published_to_resp .error , 0 )
493+ self .assertEquals (published_to_resp .highwaterMark , 1 )
494+
495+ self .assertEquals (not_published_to_resp .error , 0 )
496+ self .assertEquals (not_published_to_resp .highwaterMark , 0 )
497+
498+ messages = list (published_to_resp .messages )
467499 self .assertEquals (len (messages ), 1 )
468500 self .assertEquals (messages [0 ].message .value , "one" )
469501
502+ messages = list (not_published_to_resp .messages )
503+ self .assertEquals (len (messages ), 0 )
504+
470505 producer .stop ()
471506
472507 def test_async_keyed_producer (self ):
0 commit comments