Skip to content

Commit 5de7f05

Browse files
committed
Add advanced source transformations to reduce type checking overhead
The new 'munge' module performs transformations on the source code. It uses the AST (abstract syntax tree) representation of Python code to recognize some idioms such as `if STATIC_TYPING:` and transforms them into alternatives that have zero overhead in mpy-compiled files (e.g., `if STATIC_TYPING:` is transformed into `if 0:`, which is eliminated at compile time due to mpy-cross constant-propagation and dead branch elimination) The code assumes the input file is black-formatted. In particular, it would malfunction if an if-statement and its body are on the same line: `if STATIC_TYPING: print("boo")` would be incorrectly munged.
1 parent ff6d40d commit 5de7f05

File tree

2 files changed

+134
-29
lines changed

2 files changed

+134
-29
lines changed

circuitpython_build_tools/build.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import subprocess
3737
import tempfile
3838

39+
from .munge import munge
40+
3941
# pyproject.toml `py_modules` values that are incorrect. These should all have PRs filed!
4042
# and should be removed when the fixed version is incorporated in its respective bundle.
4143

@@ -170,16 +172,6 @@ def mpy_cross(mpy_cross_filename, circuitpython_tag, quiet=False):
170172

171173
shutil.copy("build_deps/circuitpython/mpy-cross/mpy-cross", mpy_cross_filename)
172174

173-
def _munge_to_temp(original_path, temp_file, library_version):
174-
with open(original_path, "r", encoding="utf-8") as original_file:
175-
for line in original_file:
176-
line = line.strip("\n")
177-
if line.startswith("__version__"):
178-
line = line.replace("0.0.0-auto.0", library_version)
179-
line = line.replace("0.0.0+auto.0", library_version)
180-
print(line, file=temp_file)
181-
temp_file.flush()
182-
183175
def get_package_info(library_path, package_folder_prefix):
184176
lib_path = pathlib.Path(library_path)
185177
parent_idx = len(lib_path.parts)
@@ -295,25 +287,22 @@ def library(library_path, output_directory, package_folder_prefix,
295287
full_path = os.path.join(library_path, filename)
296288
output_file = output_directory / filename.relative_to(library_path)
297289
if filename.suffix == ".py":
298-
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
299-
temp_file_name = temp_file.name
300-
try:
301-
_munge_to_temp(full_path, temp_file, library_version)
302-
temp_file.close()
303-
if mpy_cross and os.stat(temp_file.name).st_size != 0:
304-
output_file = output_file.with_suffix(".mpy")
305-
mpy_success = subprocess.call([
306-
mpy_cross,
307-
"-o", output_file,
308-
"-s", str(filename.relative_to(library_path)),
309-
temp_file.name
310-
])
311-
if mpy_success != 0:
312-
raise RuntimeError("mpy-cross failed on", full_path)
313-
else:
314-
shutil.copyfile(full_path, output_file)
315-
finally:
316-
os.remove(temp_file_name)
290+
content = munge(full_path, library_version)
291+
if mpy_cross and content:
292+
# TODO: Once 8.x bundles are no longer built, switch to
293+
# sending mpy-cross the code on stdin instead of via
294+
# temporary file (supports the "-" input argument)
295+
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
296+
temp_file.write(content)
297+
temp_file.flush()
298+
subprocess.check_output([
299+
mpy_cross,
300+
"-o", output_file.with_suffix(".mpy"),
301+
"-s", str(filename.relative_to(library_path)),
302+
temp_file.name
303+
], input=content.encode('utf-8'))
304+
else:
305+
output_file.write_text(content, encoding="utf-8")
317306
else:
318307
shutil.copyfile(full_path, output_file)
319308

circuitpython_build_tools/munge.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# The MIT License (MIT)
2+
#
3+
# Copyright (c) 2024 Jeff Epler for Adafruit Industries
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in
13+
# all copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
# THE SOFTWARE.
22+
23+
# Filter program removes some code patterns introduced by type checking,
24+
# to move towards zero overhead static typing in circuitpython libraries
25+
#
26+
# Recognized:
27+
# from __future__ import ... -- eliminated
28+
# try: import typing -- eliminated, but first except: preserved
29+
# try: from typing import ... -- eliminated, but first except: preserved
30+
# if STATIC_TYPING: -- transformed to 'if 0:'
31+
# if sys.implementation_name... -- transformed to unconditional if
32+
# __version__ = ... -- set to library version string
33+
#
34+
# mpy-cross does constant propagation and dead branch elimination of
35+
# 'if 0:' and 'if 1:'
36+
#
37+
# Depends on the file being black-formatted!
38+
39+
import pathlib
40+
import sys
41+
import ast
42+
43+
VERBOSE = 0
44+
45+
# The canonical spelling of this test...
46+
sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"'))
47+
sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"'))
48+
sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"'))
49+
50+
def munge(src: pathlib.Path|str, version_str: str) -> str:
51+
path = pathlib.Path(src)
52+
replacements = {}
53+
54+
def replace(line, new):
55+
if VERBOSE:
56+
replacements[line] = f"{new:<40s} ### {lines[line]}"
57+
else:
58+
replacements[line] = new
59+
60+
def blank_range(node):
61+
for i in range(node.lineno, node.end_lineno+1):
62+
replace(i, "")
63+
64+
def unblank_range(node):
65+
for i in range(node.lineno, node.end_lineno+1):
66+
replacements.pop(i, None)
67+
68+
def imports_from_typing(node):
69+
if isinstance(node, ast.Import) and node.names[0].name == 'typing':
70+
return True
71+
if isinstance(node, ast.ImportFrom) and node.module == 'typing':
72+
return True
73+
return False
74+
75+
def process_statement(node):
76+
# filter out 'from future import...'
77+
if isinstance(node, ast.ImportFrom):
78+
if node.module == '__future__':
79+
blank_range(node)
80+
# filter out 'try: import typing...'
81+
# but preserve the first 'except:' or 'except ImportError'
82+
elif isinstance(node, ast.Try):
83+
b = node.body[0]
84+
if imports_from_typing(node.body[0]):
85+
blank_range(node)
86+
for h in node.handlers:
87+
if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception':
88+
unblank_range(h)
89+
replace(h.lineno, 'if 1:')
90+
break
91+
return
92+
elif isinstance(node, ast.If):
93+
# return the statements in the 'if' branch of 'if sys.implementation...: ...'
94+
if ast.unparse(node.test) == sys_implementation_is_circuitpython:
95+
replace(node.lineno, 'if 1:')
96+
# return the statements in the 'else' branch of 'if sys.implementation...: ...'
97+
if ast.unparse(node.test) == sys_implementation_not_circuitpython or ast.unparse(node.test) == sys_implementation_not_circuitpython2:
98+
replace(node.lineno, 'if 0:')
99+
# return the statements in the else branch of 'if TYPE_CHECKING: ...'
100+
elif ast.unparse(node.test) == 'TYPE_CHECKING':
101+
replace(node.lineno, 'if 0:')
102+
elif isinstance(node, ast.Assign) and node.targets[0].id == '__version__':
103+
replace(node.lineno, f"__version__ = \"{version_str}\"")
104+
105+
content = pathlib.Path(path).read_text(encoding="utf-8")
106+
# Insert a blank line 0 because ast line numbers are 1-based
107+
lines = [''] + content.rstrip().split('\n')
108+
a = ast.parse(content, path.name)
109+
110+
for node in a.body: process_statement(node)
111+
112+
result = []
113+
for i in range(1, len(lines)):
114+
result.append(replacements.get(i, lines[i]))
115+
116+
return "\n".join(result) + "\n"

0 commit comments

Comments
 (0)