44except ImportError :
55 import unittest
66import hypothesis .strategies as st
7+ import pytest
78from hypothesis import given , settings , example
89try :
910 from hypothesis import HealthCheck
@@ -56,40 +57,22 @@ def test_numbertheory():
5657 for i in range (len (bigprimes ) - 1 ):
5758 assert next_prime (bigprimes [i ]) == bigprimes [i + 1 ]
5859
59- error_tally = 0
60-
61- # Test the square_root_mod_prime function:
62-
63- for p in smallprimes :
64- print_ ("Testing square_root_mod_prime for modulus p = %d." % p )
65- squares = []
66-
67- for root in range (0 , 1 + p // 2 ):
68- sq = (root * root ) % p
69- squares .append (sq )
70- calculated = square_root_mod_prime (sq , p )
71- if (calculated * calculated ) % p != sq :
72- error_tally = error_tally + 1
73- print_ ("Failed to find %d as sqrt( %d ) mod %d. Said %d." % \
74- (root , sq , p , calculated ))
75-
76- for nonsquare in range (0 , p ):
77- if nonsquare not in squares :
78- try :
79- calculated = square_root_mod_prime (nonsquare , p )
80- except SquareRootError :
81- pass
82- else :
83- error_tally = error_tally + 1
84- print_ ("Failed to report no root for sqrt( %d ) mod %d." % \
85- (nonsquare , p ))
8660
87- class FailedTest (Exception ):
88- pass
61+ @pytest .mark .parametrize ("prime" , smallprimes )
62+ def test_square_root_mod_prime_for_small_primes (prime ):
63+ squares = set ()
64+ for num in range (0 , 1 + prime // 2 ):
65+ sq = num * num % prime
66+ squares .add (sq )
67+ root = square_root_mod_prime (sq , prime )
68+ # tested for real with TestNumbertheory.test_square_root_mod_prime
69+ assert root * root % prime == sq
8970
90- print_ (error_tally , "errors detected." )
91- if error_tally != 0 :
92- raise FailedTest ("%d errors detected" % error_tally )
71+ for nonsquare in range (0 , prime ):
72+ if nonsquare in squares :
73+ continue
74+ with pytest .raises (SquareRootError ):
75+ square_root_mod_prime (nonsquare , prime )
9376
9477
9578@st .composite
@@ -102,6 +85,24 @@ def st_two_nums_rel_prime(draw):
10285 return num , mod
10386
10487
88+ @st .composite
89+ def st_primes (draw , * args , ** kwargs ):
90+ if "min_value" not in kwargs :
91+ kwargs ["min_value" ] = 1
92+ prime = draw (st .sampled_from (smallprimes ) |
93+ st .integers (* args , ** kwargs )
94+ .filter (lambda x : is_prime (x )))
95+ return prime
96+
97+
98+ @st .composite
99+ def st_num_square_prime (draw ):
100+ prime = draw (st_primes (max_value = 2 ** 1024 ))
101+ num = draw (st .integers (min_value = 0 , max_value = 1 + prime // 2 ))
102+ sq = num * num % prime
103+ return sq , prime
104+
105+
105106HYP_SETTINGS = {}
106107if HC_PRESENT :
107108 HYP_SETTINGS ['suppress_health_check' ]= [HealthCheck .filter_too_much ,
@@ -111,6 +112,17 @@ def st_two_nums_rel_prime(draw):
111112
112113
113114class TestNumbertheory (unittest .TestCase ):
115+ @unittest .skipUnless (HC_PRESENT ,
116+ "Hypothesis 2.0.0 can't be made tolerant of hard to "
117+ "meet requirements (like `is_prime()`)" )
118+ @settings (** HYP_SETTINGS )
119+ @given (st_num_square_prime ())
120+ def test_square_root_mod_prime (self , vals ):
121+ square , prime = vals
122+
123+ calc = square_root_mod_prime (square , prime )
124+ assert calc * calc % prime == square
125+
114126 @settings (** HYP_SETTINGS )
115127 @given (st .integers (min_value = 1 , max_value = 10 ** 12 ))
116128 @example (265399 * 1526929 )
0 commit comments