|
| 1 | +from six import print_ |
| 2 | +try: |
| 3 | + import unittest2 as unittest |
| 4 | +except ImportError: |
| 5 | + import unittest |
| 6 | +import hypothesis.strategies as st |
| 7 | +from hypothesis import given |
1 | 8 | from .numbertheory import (SquareRootError, factorization, gcd, lcm, |
2 | 9 | jacobi, inverse_mod, |
3 | 10 | is_prime, next_prime, smallprimes, |
4 | 11 | square_root_mod_prime) |
5 | | -from six import print_ |
6 | 12 |
|
7 | 13 | def test_numbertheory(): |
8 | 14 |
|
@@ -99,26 +105,30 @@ def test_numbertheory(): |
99 | 105 | print_("%d != jacobi( %d, %d )" % (c, a, m)) |
100 | 106 |
|
101 | 107 |
|
102 | | -# Test the inverse_mod function: |
103 | | - print_("Testing inverse_mod . . .") |
104 | | - import random |
105 | | - n_tests = 0 |
106 | | - for i in range(100): |
107 | | - m = random.randint(20, 10000) |
108 | | - for j in range(100): |
109 | | - a = random.randint(1, m - 1) |
110 | | - if gcd(a, m) == 1: |
111 | | - n_tests = n_tests + 1 |
112 | | - inv = inverse_mod(a, m) |
113 | | - if inv <= 0 or inv >= m or (a * inv) % m != 1: |
114 | | - error_tally = error_tally + 1 |
115 | | - print_("%d = inverse_mod( %d, %d ) is wrong." % (inv, a, m)) |
116 | | - assert n_tests > 1000 |
117 | | - print_(n_tests, " tests of inverse_mod completed.") |
118 | | - |
119 | 108 | class FailedTest(Exception): |
120 | 109 | pass |
121 | 110 |
|
122 | 111 | print_(error_tally, "errors detected.") |
123 | 112 | if error_tally != 0: |
124 | 113 | raise FailedTest("%d errors detected" % error_tally) |
| 114 | + |
| 115 | + |
| 116 | +@st.composite |
| 117 | +def st_two_nums_rel_prime(draw): |
| 118 | + # 521-bit is the biggest curve we operate on, use 1024 for a bit |
| 119 | + # of breathing space |
| 120 | + mod = draw(st.integers(min_value=2, max_value=2**1024)) |
| 121 | + num = draw(st.integers(min_value=1, max_value=mod-1) |
| 122 | + .filter(lambda x: gcd(x, mod) == 1)) |
| 123 | + return num, mod |
| 124 | + |
| 125 | + |
| 126 | +class TestNumbertheory(unittest.TestCase): |
| 127 | + @given(st_two_nums_rel_prime()) |
| 128 | + def test_inverse_mod(self, nums): |
| 129 | + num, mod = nums |
| 130 | + |
| 131 | + inv = inverse_mod(num, mod) |
| 132 | + |
| 133 | + assert 0 < inv < mod |
| 134 | + assert num * inv % mod == 1 |
0 commit comments