Skip to content

Commit 8a25f12

Browse files
hanno-beckermkannwischer
authored andcommitted
check-magic: Remember all explained magic constants
Previously, scripts/check-magic would remember only the last explained magic constant, preventing, for example, the explanation of multiple magic constants ahead of a comment block referring to all of them. Moreover, check-magic would only lazily evaluate a provided explanation when actually finding a magic value magic the LHS of the proposed explanation. In particular, a _wrong_ explanation would only be caught if, in the rest of the file under consideration, some matching magic constant would be found. This commit makes check-magic more general so that - it always checks magic value explanations when they are provided, regardless of whether they are needed or not; and, - it remembers all magic values explained so far. Moreover, the `round` function is instrumented to fail if it is called on an odd multiple of 1/2 -- in this case, the rounding is ambiguous (do we want round-half-down or round-half-up?). We also add support for `intdiv(a,b)` to an integer division which we want to assert to be without residue. This can be used instead of `//` to additionally check that the division is indeed integral. Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent cdda8b4 commit 8a25f12

File tree

1 file changed

+39
-12
lines changed

1 file changed

+39
-12
lines changed

scripts/check-magic

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
#
88

99
import re
10+
import math
1011
import pathlib
1112

12-
from sympy import simplify, sympify, Function
13+
from sympy import simplify, sympify, Function, Rational
1314

1415
def get_c_source_files():
1516
return get_files("mlkem/**/*.c")
@@ -20,6 +21,17 @@ def get_header_files():
2021
def get_files(pattern):
2122
return list(map(str, pathlib.Path().glob(pattern)))
2223

24+
# Standard color definitions
25+
GREEN="\033[32m"
26+
RED="\033[31m"
27+
BLUE="\033[94m"
28+
BOLD="\033[1m"
29+
NORMAL="\033[0m"
30+
31+
CHECKED = f"{GREEN}{NORMAL}"
32+
FAIL = f"{RED}{NORMAL}"
33+
REMEMBERED = f"{BLUE}{NORMAL}"
34+
2335
def check_magic_numbers():
2436
mlkem_q = 3329
2537
exceptions = [mlkem_q,
@@ -64,9 +76,21 @@ def check_magic_numbers():
6476
y = int(y)
6577
m = int(m)
6678
return signed_mod(pow(x,y,m),m)
79+
def safe_round(x):
80+
if x - math.floor(x) == Rational(1, 2):
81+
raise ValueError(f"Ambiguous rounding: {x} is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired")
82+
return round(x)
83+
def safe_floordiv(x, y):
84+
x = int(x)
85+
y = int(y)
86+
if x % y != 0:
87+
raise ValueError(f"Non-integral division: {x} // {y} has remainder {x % y}")
88+
return x // y
6789
locals_dict = {'signed_mod': signed_mod,
6890
'unsigned_mod': unsigned_mod,
69-
'pow': pow_mod }
91+
'pow': pow_mod,
92+
'round': safe_round,
93+
'intdiv': safe_floordiv }
7094
locals_dict.update(known_magics)
7195
return sympify(m, locals=locals_dict)
7296

@@ -82,6 +106,7 @@ def check_magic_numbers():
82106
enabled = True
83107
magic_dict = {'MLKEM_Q': mlkem_q}
84108
magic_expr = None
109+
verified_magics = {}
85110
for i, l in enumerate(content):
86111
if enabled is True and disable_marker in l:
87112
enabled = False
@@ -94,6 +119,12 @@ def check_magic_numbers():
94119
l, g = get_magic(l)
95120
if g is not None:
96121
magic_val, magic_expr = g
122+
magic_val_check = evaluate_magic(magic_expr, magic_dict)
123+
if magic_val != magic_val_check:
124+
print(f"{FAIL}:{filename}:{i}: Mismatching magic annotation: {magic_val} != {magic_expr} (= {magic_val_check})")
125+
exit(1)
126+
print(f"{REMEMBERED}:{filename}:{i}: Verified explanation {magic_val} == {magic_expr}")
127+
verified_magics[magic_val] = magic_expr
97128

98129
found = next(re.finditer(pattern, l), None)
99130
if found is None:
@@ -103,16 +134,12 @@ def check_magic_numbers():
103134
if is_exception(filename, l, magic):
104135
continue
105136

106-
if magic_expr is not None:
107-
val = evaluate_magic(magic_expr, magic_dict)
108-
if magic_val != val:
109-
raise Exception(f"{filename}:{i}: Mismatching magic annotation: {magic_val} != {val}")
110-
if val == magic:
111-
print(f"[OK] {filename}:{i}: Verified magic constant {magic} == {magic_expr}")
112-
else:
113-
raise Exception(f"{filename}:{i}: Magic constant mismatch {magic} != {magic_expr}")
114-
else:
115-
raise Exception(f"{filename}:{i}: No explanation for magic value {magic}")
137+
explanation = verified_magics.get(magic, None)
138+
if explanation is None:
139+
print(f"{FAIL}:{filename}:{i}: No explanation for magic value {magic}")
140+
exit(1)
141+
142+
print(f"{CHECKED}:{filename}:{i}: {magic} previously explained as {explanation}")
116143

117144
# If this is a #define's clause, remember it
118145
define = get_define(l)

0 commit comments

Comments
 (0)