1+ import logging
12import random
23import time
34import unittest
2829
2930
3031class TestRollbackError (Exception ):
31- pass
32+ __test__ = False # Silence pytest warning
3233
3334
3435class TestableThread (Thread ):
@@ -38,6 +39,8 @@ class TestableThread(Thread):
3839 REF: https://gist.github.com/sbrugman/59b3535ebcd5aa0e2598293cfa58b6ab
3940 """
4041
42+ __test__ = False # Silence pytest warning
43+
4144 def __init__ (self , * args , ** kwargs ):
4245 super ().__init__ (* args , ** kwargs )
4346 self .exc = None
@@ -753,14 +756,15 @@ def run_tx():
753756
754757 except TestRollbackError :
755758 pass
756- # except OperationError as op_failure:
757- # """
758- # See thread safety test below for more details about TransientTransactionError handling
759- # """
760- # if "TransientTransactionError" in str(op_failure):
761- # run_tx()
762- # else:
763- # raise op_failure
759+ except OperationError as op_failure :
760+ """
761+ See thread safety test below for more details about TransientTransactionError handling
762+ """
763+ if "TransientTransactionError" in str (op_failure ):
764+ logging .warning ("TransientTransactionError - retrying..." )
765+ run_tx ()
766+ else :
767+ raise op_failure
764768
765769 run_tx ()
766770 assert "a" == A .objects .get (id = a_doc .id ).name
@@ -789,10 +793,10 @@ def test_thread_safety_of_transactions(self):
789793 case, then no amount of runtime variability should have
790794 an effect on the output.
791795
792- This test sets up 10 records, each with an integer field
796+ This test sets up e.g 10 records, each with an integer field
793797 of value 0 - 9.
794798
795- We then spin up 10 threads and attempt to update a target
799+ We then spin up e.g 10 threads and attempt to update a target
796800 record by multiplying its integer value by 10. Then, if
797801 the target record is even, throw an exception, which
798802 should then roll the transaction back. The odd rows always
@@ -807,24 +811,26 @@ def test_thread_safety_of_transactions(self):
807811 connect ("mongoenginetest" )
808812
809813 class A (Document ):
810- i = IntField ()
814+ i = IntField (unique = True )
811815
812816 A .drop_collection ()
813817 # Ensure the collection is created
814- A .objects .create (i = 0 )
818+ _ = A .objects .first ()
819+
820+ thread_count = 20
815821
816822 def thread_fn (idx ):
817823 # Open the transaction at some unknown interval
818824 time .sleep (random .uniform (0.1 , 0.5 ))
819825 try :
820826 with run_in_transaction ():
821827 a = A .objects .get (i = idx )
822- a .i = idx * 10
828+ a .i = idx * thread_count
823829 # Save at some unknown interval
824830 time .sleep (random .uniform (0.1 , 0.5 ))
825831 a .save ()
826832
827- # Force roll backs for the even runs...
833+ # Force rollbacks for the even runs...
828834 if idx % 2 == 0 :
829835 raise TestRollbackError ()
830836 except TestRollbackError :
@@ -841,6 +847,7 @@ def thread_fn(idx):
841847 """
842848 error_labels = op_failure .details .get ("errorLabels" , [])
843849 if "TransientTransactionError" in error_labels :
850+ logging .warning ("TransientTransactionError - retrying..." )
844851 thread_fn (idx )
845852 else :
846853 raise op_failure
@@ -854,7 +861,6 @@ def thread_fn(idx):
854861 A .objects .all ().delete ()
855862
856863 # Prepopulate the data for reads
857- thread_count = 20
858864 for i in range (thread_count ):
859865 A .objects .create (i = i )
860866
@@ -868,13 +874,10 @@ def thread_fn(idx):
868874 t .join ()
869875
870876 # Check the sum
871- expected_sum = 0
872- for i in range (thread_count ):
873- if i % 2 == 0 :
874- expected_sum += i
875- else :
876- expected_sum += i * 10
877- assert expected_sum == 1090
877+ expected_sum = sum (
878+ i if i % 2 == 0 else i * thread_count for i in range (thread_count )
879+ )
880+ assert expected_sum == 2090
878881 assert expected_sum == sum (a .i for a in A .objects .all ())
879882
880883
0 commit comments