Skip to content

Commit 7bdff98

Browse files
committed
NativeCompiler/compiler_cache pass function definitions to the llvm compiler to allow for cross-module-inlining.
1 parent 89cc6ef commit 7bdff98

File tree

10 files changed

+440
-105
lines changed

10 files changed

+440
-105
lines changed

typed_python/compiler/native_compiler/binary_shared_object.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,38 +36,49 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, globalVariabl
3636
class BinarySharedObject:
3737
"""Models a shared object library (.so) loadable on linux systems."""
3838

39-
def __init__(self, binaryForm, functionTypes, globalVariableDefinitions, usedExternalFunctions):
39+
def __init__(
    self,
    binaryForm,
    functionTypes,
    globalVariableDefinitions,
    usedExternalFunctions,
    functionDefinitions,
):
    """Record the compiled bytes of a module together with its metadata.

    Args:
        binaryForm - a bytes object holding the compiled code for the module.
        functionTypes - a map from linkerName to native_ast.Type.Function giving
            the native signature of each function.
        globalVariableDefinitions - a map from name to GlobalVariableDefinition.
        usedExternalFunctions - the symbols defined in other modules that
            this module needs.
        functionDefinitions - a dict from symbol to native_ast.Function object.
    """
    self.functionDefinitions = functionDefinitions
    self.usedExternalFunctions = usedExternalFunctions
    self.globalVariableDefinitions = globalVariableDefinitions
    self.functionTypes = functionTypes
    self.binaryForm = binaryForm

    # the object's hash is computed from the compiled bytes alone
    self.hash = sha_hash(binaryForm)
5262

5363
@property
5464
def definedSymbols(self):
5565
return self.functionTypes.keys()
5666

5767
@staticmethod
58-
def fromDisk(path, globalVariableDefinitions, functionNameToType, usedExternalFunctions):
68+
def fromDisk(path, globalVariableDefinitions, functionNameToType, usedExternalFunctions, functionDefinitions):
    """Build a BinarySharedObject from the compiled bytes stored in the file at 'path'.

    The remaining arguments are handed to the BinarySharedObject constructor
    unchanged; 'functionNameToType' becomes its 'functionTypes'.
    """
    with open(path, "rb") as sharedObjectFile:
        compiledBytes = sharedObjectFile.read()

    return BinarySharedObject(
        binaryForm=compiledBytes,
        functionTypes=functionNameToType,
        globalVariableDefinitions=globalVariableDefinitions,
        usedExternalFunctions=usedExternalFunctions,
        functionDefinitions=functionDefinitions,
    )
6879

6980
@staticmethod
70-
def fromModule(module, globalVariableDefinitions, functionNameToType, usedExternalFunctions):
81+
def fromModule(module, globalVariableDefinitions, functionNameToType, usedExternalFunctions, functionDefinitions):
7182
target_triple = llvm.get_process_triple()
7283
target = llvm.Target.from_triple(target_triple)
7384
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
@@ -92,7 +103,8 @@ def fromModule(module, globalVariableDefinitions, functionNameToType, usedExtern
92103
so_file.read(),
93104
functionNameToType,
94105
globalVariableDefinitions,
95-
usedExternalFunctions
106+
usedExternalFunctions,
107+
functionDefinitions
96108
)
97109

98110
def load(self, storageDir):

typed_python/compiler/native_compiler/compiler_cache.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import os
1616
import uuid
1717
import shutil
18+
from typed_python.compiler.native_compiler.native_ast import Function
19+
from typed_python.compiler.native_compiler.native_ast_analysis import extractNamedCallTargets
1820
from typed_python.compiler.native_compiler.loaded_module import LoadedModule
1921
from typed_python.compiler.native_compiler.binary_shared_object import BinarySharedObject
2022

@@ -47,8 +49,9 @@ class CompilerCache:
4749
by making it possible to determine if a given function is in the cache by organizing
4850
the manifests by, say, function name.
4951
"""
50-
def __init__(self, cacheDir):
52+
def __init__(self, cacheDir, checkModuleValidity=True):
5153
self.cacheDir = cacheDir
54+
self.checkModuleValidity = checkModuleValidity
5255

5356
ensureDirExists(cacheDir)
5457

@@ -102,9 +105,10 @@ def loadForSymbol(self, symbol):
102105

103106
nameToTypedCallTarget = {}
104107
nameToNativeFunctionType = {}
108+
nameToDefinition = {}
105109

106-
if self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType):
107-
return nameToTypedCallTarget, nameToNativeFunctionType
110+
if self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType, nameToDefinition):
111+
return nameToTypedCallTarget, nameToNativeFunctionType, nameToDefinition
108112
else:
109113
assert (
110114
# either we can't load this symbol at all anymore
@@ -113,7 +117,13 @@ def loadForSymbol(self, symbol):
113117
or moduleHash not in self.symbolToModuleHashes[symbol]
114118
)
115119

116-
def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFunctionType):
120+
def loadModuleByHash(
121+
self,
122+
moduleHash,
123+
nameToTypedCallTarget,
124+
nameToNativeFunctionType,
125+
nameToDefinition
126+
):
117127
"""Load a module by name.
118128
119129
As we load, place all the newly imported typed call targets into
@@ -141,6 +151,11 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti
141151
with open(os.path.join(targetDir, "linkDependencies.dat"), "rb") as f:
142152
linkDependencies = SerializationContext().deserialize(f.read(), ListOf(str))
143153

154+
with open(os.path.join(targetDir, "functionDefinitions.dat"), "rb") as f:
155+
functionDefinitions = SerializationContext().deserialize(
156+
f.read(), Dict(str, Function)
157+
)
158+
144159
except Exception:
145160
self.markModuleHashInvalid(moduleHash)
146161
return False
@@ -154,7 +169,8 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti
154169
if not self.loadModuleByHash(
155170
submodule,
156171
nameToTypedCallTarget,
157-
nameToNativeFunctionType
172+
nameToNativeFunctionType,
173+
nameToDefinition
158174
):
159175
return False
160176

@@ -164,13 +180,15 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti
164180
modulePath,
165181
globalVarDefs,
166182
functionNameToNativeType,
167-
linkDependencies
183+
linkDependencies,
184+
functionDefinitions
168185
).loadFromPath(modulePath)
169186

170187
self.loadedModules[moduleHash] = loaded
171188

172189
nameToTypedCallTarget.update(callTargets)
173190
nameToNativeFunctionType.update(functionNameToNativeType)
191+
nameToDefinition.update(functionDefinitions)
174192

175193
for symbol in functionNameToNativeType:
176194
if symbol not in self.symbolToLoadedModuleHash:
@@ -188,6 +206,37 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget):
188206
the formal python types for all the objects
189207
linkDependencies - a set of linknames we depend on directly.
190208
"""
209+
if self.checkModuleValidity:
210+
externals = extractNamedCallTargets(
211+
binarySharedObject.functionDefinitions
212+
)
213+
214+
statedNames = set(binarySharedObject.usedExternalFunctions)
215+
216+
expectedNames = set(e.name for e in externals if not e.external) - set(
217+
binarySharedObject.functionDefinitions
218+
)
219+
220+
if statedNames != expectedNames:
221+
if expectedNames - statedNames:
222+
raise Exception(
223+
"Invalid shared object - link dependencies don't match "
224+
+ "stated shared object dependencies:\n\n"
225+
+ "".join(
226+
[' ' + x + "\n" for x in sorted(expectedNames - statedNames)]
227+
)
228+
+ "\nwere referenced but not claimed in the manifest."
229+
)
230+
else:
231+
raise Exception(
232+
"Invalid shared object - link dependencies don't match "
233+
+ "stated shared object dependencies:\n\n"
234+
+ "".join(
235+
[' ' + x + "\n" for x in sorted(statedNames - expectedNames)]
236+
)
237+
+ "\nwere claimed in the manifest but don't seem to be referenced"
238+
)
239+
191240
dependentHashes = set()
192241

193242
for name in binarySharedObject.usedExternalFunctions:
@@ -308,6 +357,14 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
308357
)
309358
)
310359

360+
with open(os.path.join(tempTargetDir, "functionDefinitions.dat"), "wb") as f:
361+
f.write(
362+
SerializationContext().serialize(
363+
Dict(str, Function)(binarySharedObject.functionDefinitions),
364+
Dict(str, Function)
365+
)
366+
)
367+
311368
try:
312369
os.rename(tempTargetDir, targetDir)
313370
except IOError:
@@ -324,6 +381,6 @@ def function_pointer_by_name(self, linkName):
324381
raise Exception("Can't find a module for " + linkName)
325382

326383
if moduleHash not in self.loadedModules:
327-
self.loadForSymbol(linkName)
384+
raise Exception("You need to call 'loadForSymbol' on this linkName first")
328385

329386
return self.loadedModules[moduleHash].functionPointers[linkName]

typed_python/compiler/native_compiler/compiler_cache_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,57 @@ def test_compiler_cache_can_handle_conflicting_versions_of_the_same_code():
5050
assert len(os.listdir(compilerCacheDir)) == 2
5151

5252

53+
@pytest.mark.skipif('sys.platform=="darwin"')
def test_compiler_cache_inlining():
    """Compare the speed of 'fList' when 'f' lives in the same module vs. another one.

    VERSION1 has 'testmodule' import 'fList' from 'x', where 'f' is defined
    right next to it. VERSION2 has 'testmodule' import 'fList' from 'y', which
    in turn imports 'f' from 'x'. Both versions are evaluated through
    'evaluateExprInFreshProcess' against the same compiler cache directory, and
    the second timing must be less than twice the first.

    NOTE(review): the assertion compares wall-clock timings, so it can be
    sensitive to machine load.
    """
    fListDef = "\n".join([
        "from typed_python import Entrypoint",
        "@Entrypoint",
        "def fList(aList):",
        "    res = 0",
        "    for x in aList:",
        "        res += f(x)",
        "    return res",
    ])
    xmodule = "\n".join([
        "def f(x):",
        "    return x",
        fListDef,
    ])
    ymodule = "\n".join([
        "from x import f",
        fListDef,
    ])
    testmodule = "\n".join([
        "import time",
        "from typed_python import Entrypoint, ListOf",
        "from x import fList",
        "@Entrypoint",
        "def makeRange(N):",
        "    res = ListOf(int)()",
        "    for i in range(N):",
        "        res.append(i)",
        "    return res",
        "N = 10000000",
        "bigRange = makeRange(N)",
        "fList(bigRange)",
        "t0 = time.time()",
        "fList(bigRange)",
        "duration = time.time() - t0",
    ])

    VERSION1 = {'x.py': xmodule, 'testmodule.py': testmodule}
    VERSION2 = {
        'x.py': xmodule,
        'y.py': ymodule,
        # Only retarget the import statement. Replacing every 'x' character would
        # silently corrupt the module as soon as any other line contained one.
        'testmodule.py': testmodule.replace('from x import', 'from y import'),
    }

    with tempfile.TemporaryDirectory() as compilerCacheDir:
        dur1 = evaluateExprInFreshProcess(VERSION1, 'testmodule.duration', compilerCacheDir)
        dur2 = evaluateExprInFreshProcess(VERSION2, 'testmodule.duration', compilerCacheDir)

        print(dur1)
        print(dur2)

        assert dur2 < 2 * dur1
102+
103+
53104
@pytest.mark.skipif('sys.platform=="darwin"')
54105
def test_compiler_cache_can_detect_invalidation_through_modules():
55106
xmodule = "\n".join([

typed_python/compiler/native_compiler/module_definition.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ModuleDefinition:
2424
globalDefinitions - a dict from name to a GlobalDefinition
2525
usedExternalFunctions - a set of symbols of functions that were compiled
2626
in different modules that we depend on. These will all be in some module.
27+
functionDefinitions - a dict from name to Function
2728
"""
2829
GET_GLOBAL_VARIABLES_NAME = ".get_global_variables"
2930

@@ -32,10 +33,12 @@ def __init__(
3233
moduleText,
3334
functionNameToType,
3435
globalVariableDefinitions,
35-
usedExternalFunctions
36+
usedExternalFunctions,
37+
functionDefinitions
3638
):
3739
self.moduleText = moduleText
3840
self.functionNameToType = functionNameToType
3941
self.globalVariableDefinitions = globalVariableDefinitions
4042
self.hash = sha_hash(moduleText)
4143
self.usedExternalFunctions = usedExternalFunctions
44+
self.functionDefinitions = functionDefinitions

typed_python/compiler/native_compiler/native_ast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ def expr_could_throw(self):
569569
Binop={'op': BinaryOp, 'left': Expression, 'right': Expression},
570570
Unaryop={'op': UnaryOp, 'operand': Expression},
571571
Variable={'name': str},
572-
Attribute={'left': Expression, 'attr': str},
573572
StructElementByIndex={'left': Expression, 'index': int},
574573
ElementPtr={'left': Expression, 'offsets': TupleOf(Expression)},
575574
Call={'target': CallTarget, 'args': TupleOf(Expression)},
typed_python/compiler/native_compiler/native_ast_analysis.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2023 typed_python Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from typed_python import TupleOf, ListOf, Tuple, Dict, ConstDict, NamedTuple
17+
18+
from typed_python.compiler.native_compiler.native_ast import (
19+
Constant,
20+
Type,
21+
UnaryOp,
22+
BinaryOp,
23+
NamedCallTarget,
24+
Expression,
25+
Teardown,
26+
ExpressionIntermediate,
27+
CallTarget,
28+
FunctionBody,
29+
Function,
30+
GlobalVariableMetadata
31+
)
32+
33+
34+
def visitAstChildren(node, callback):
35+
if not callback(node):
36+
return
37+
38+
# don't look in these
39+
if isinstance(node, (UnaryOp, BinaryOp, Constant, Type, NamedCallTarget, GlobalVariableMetadata)):
40+
return
41+
42+
if isinstance(node, (int, float, str, bytes, bool, type(None))):
43+
return
44+
45+
if isinstance(node, Function):
46+
visitAstChildren(node.args, callback)
47+
visitAstChildren(node.body, callback)
48+
visitAstChildren(node.output_type, callback)
49+
return
50+
51+
if isinstance(node, (Expression, Teardown, ExpressionIntermediate, CallTarget, FunctionBody)):
52+
for name in node.ElementType.ElementNames:
53+
visitAstChildren(getattr(node, name), callback)
54+
return
55+
56+
if isinstance(node, (Dict, ConstDict, dict)):
57+
for k, v in node.items():
58+
visitAstChildren(k, callback)
59+
visitAstChildren(v, callback)
60+
return
61+
62+
if isinstance(node, (TupleOf, ListOf, tuple, list, Tuple, NamedTuple)):
63+
for child in node:
64+
visitAstChildren(child, callback)
65+
return
66+
67+
raise Exception(f"Unexpected AST node of type {type(node).__name__}")
68+
69+
70+
def extractNamedCallTargets(ast):
    """Return the set of every NamedCallTarget that visitAstChildren reaches from 'ast'."""
    found = set()

    def collect(candidate):
        if isinstance(candidate, NamedCallTarget):
            found.add(candidate)

        # never prune: always keep walking
        return True

    visitAstChildren(ast, collect)

    return found

0 commit comments

Comments
 (0)