Skip to content

Commit 9954544

Browse files
committed
More carefully encapsulate the compiler cache within NativeCompiler.
We want for NativeCompiler to provide a simple, clean abstraction for a code-compilation service. You should be able to push code in, ask whether you've pushed code before, and get TypedCallTargets and function pointers out, and not really need to worry about the compiler cache. This change moves the compiler cache fully behind the native compiler, and tries to ensure that the python converter doesn't itself need to track the functions that its compiled in the past. Instead, we're moving towards it being stateless across invocations (and we'll probably do the same thing for the native-to-llvm converter as well) since its substantially simpler to think about. As part of this change, we modify the compiler cache to acknowledge that modules inside might not be loadable and that as a result, it might think that a symbol is available when it is not. In this pathway, the compiler cache tracks all modules that define a symbol, as well as the first module that it loads that defines a symbol, and if its tries to load all possible modules that provide a given symbol and fails, then returns without loading the symbol. NativeCompiler doesn't return a positive indication that it has a symbol until it has actually loaded the symbol. Eventually we may not need this because we'll be loading modules partially.
1 parent 638fe5b commit 9954544

File tree

5 files changed

+205
-96
lines changed

5 files changed

+205
-96
lines changed

typed_python/compiler/native_compiler/binary_shared_object.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def __init__(self, binaryForm, functionTypes, globalVariableDefinitions):
4040
"""
4141
Args:
4242
binaryForm - a bytes object containing the actual compiled code for the module
43+
functionTypes - a map from linkerName to native_ast.Type.Function containing the native
44+
signatures of the functions
4345
globalVariableDefinitions - a map from name to GlobalVariableDefinition
4446
"""
4547
self.binaryForm = binaryForm

typed_python/compiler/native_compiler/compiler_cache.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typed_python.compiler.native_compiler.binary_shared_object import BinarySharedObject
2020

2121
from typed_python.SerializationContext import SerializationContext
22-
from typed_python import Dict, ListOf
22+
from typed_python import Dict, ListOf, Set
2323

2424

2525
def ensureDirExists(cacheDir):
@@ -53,32 +53,65 @@ def __init__(self, cacheDir):
5353
ensureDirExists(cacheDir)
5454

5555
self.loadedModules = Dict(str, LoadedModule)()
56-
self.nameToModuleHash = Dict(str, str)()
5756

57+
# for each symbol, the first module we loaded that has that symbol
58+
self.symbolToLoadedModuleHash = Dict(str, str)()
59+
60+
# for each module that we loaded or might load, the contents
61+
self.symbolToModuleHashes = Dict(str, Set(str))()
62+
self.moduleHashToSymbols = Dict(str, Set(str))()
63+
64+
# modules we might be able to load
5865
self.modulesMarkedValid = set()
66+
67+
# modules we definitely can't load
5968
self.modulesMarkedInvalid = set()
6069

6170
for moduleHash in os.listdir(self.cacheDir):
6271
if len(moduleHash) == 40:
6372
self.loadNameManifestFromStoredModuleByHash(moduleHash)
6473

65-
def hasSymbol(self, linkName):
66-
return linkName in self.nameToModuleHash
74+
def hasSymbol(self, symbol):
75+
"""Do we have this symbol defined somewhere?
6776
68-
def markModuleHashInvalid(self, hashstr):
69-
with open(os.path.join(self.cacheDir, hashstr, "marked_invalid"), "w"):
70-
pass
77+
Note that this can change: if we attempt to laod a symbol and fail,
78+
then it may no longer be defined anywhere. To really know if you have
79+
a symbol, you have to load it.
80+
"""
81+
return symbol in self.symbolToModuleHashes
7182

72-
def loadForSymbol(self, linkName):
73-
moduleHash = self.nameToModuleHash[linkName]
83+
def markModuleHashInvalid(self, moduleHash):
84+
"""Mark this module unloadable on disk and remove its symbols."""
85+
with open(os.path.join(self.cacheDir, moduleHash, "marked_invalid"), "w"):
86+
pass
7487

75-
nameToTypedCallTarget = {}
76-
nameToNativeFunctionType = {}
88+
# remove any symbols that we can't see anymore
89+
for symbol in self.moduleHashToSymbols.pop(moduleHash, Set(str)()):
90+
hashes = self.symbolToModuleHashes[symbol]
91+
hashes.discard(moduleHash)
92+
if not hashes:
93+
del self.symbolToModuleHashes[symbol]
7794

78-
if not self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType):
95+
def loadForSymbol(self, symbol):
96+
# check if this symbol is already loaded
97+
if symbol in self.symbolToLoadedModuleHash:
7998
return None
8099

81-
return nameToTypedCallTarget, nameToNativeFunctionType
100+
while symbol in self.symbolToModuleHashes:
101+
moduleHash = list(self.symbolToModuleHashes[symbol])[0]
102+
103+
nameToTypedCallTarget = {}
104+
nameToNativeFunctionType = {}
105+
106+
if self.loadModuleByHash(moduleHash, nameToTypedCallTarget, nameToNativeFunctionType):
107+
return nameToTypedCallTarget, nameToNativeFunctionType
108+
else:
109+
assert (
110+
# either we can't load this symbol at all anymore
111+
symbol not in self.symbolToModuleHashes
112+
# or confirm we can't try to load this again
113+
or moduleHash not in self.symbolToModuleHashes[symbol]
114+
)
82115

83116
def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFunctionType):
84117
"""Load a module by name.
@@ -134,6 +167,10 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti
134167
nameToTypedCallTarget.update(callTargets)
135168
nameToNativeFunctionType.update(functionNameToNativeType)
136169

170+
for symbol in functionNameToNativeType:
171+
if symbol not in self.symbolToLoadedModuleHash:
172+
self.symbolToLoadedModuleHash[symbol] = moduleHash
173+
137174
return True
138175

139176
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies):
@@ -149,16 +186,19 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies)
149186
dependentHashes = set()
150187

151188
for name in linkDependencies:
152-
dependentHashes.add(self.nameToModuleHash[name])
189+
dependentHashes.add(self.symbolToLoadedModuleHash[name])
153190

154191
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes)
155192

156193
self.loadedModules[hashToUse] = (
157194
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
158195
)
159196

160-
for n in binarySharedObject.definedSymbols:
161-
self.nameToModuleHash[n] = hashToUse
197+
for symbol in binarySharedObject.definedSymbols:
198+
if symbol not in self.symbolToLoadedModuleHash:
199+
self.symbolToLoadedModuleHash[symbol] = hashToUse
200+
self.symbolToModuleHashes.setdefault(symbol).add(hashToUse)
201+
self.moduleHashToSymbols[hashToUse] = Set(str)(binarySharedObject.definedSymbols)
162202

163203
def loadNameManifestFromStoredModuleByHash(self, moduleHash):
164204
if moduleHash in self.modulesMarkedValid:
@@ -185,9 +225,11 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash):
185225
return False
186226

187227
with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
188-
self.nameToModuleHash.update(
189-
SerializationContext().deserialize(f.read(), Dict(str, str))
190-
)
228+
manifest = SerializationContext().deserialize(f.read(), Set(str))
229+
230+
for symbolName in manifest:
231+
self.symbolToModuleHashes.setdefault(symbolName).add(moduleHash)
232+
self.moduleHashToSymbols[moduleHash] = manifest
191233

192234
self.modulesMarkedValid.add(moduleHash)
193235

@@ -225,12 +267,10 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
225267

226268
# write the manifest. Every TP process using the cache will have to
227269
# load the manifest every time, so we try to use compiled code to load it
228-
manifest = Dict(str, str)()
229-
for n in binarySharedObject.functionTypes:
230-
manifest[n] = hashToUse
270+
manifest = Set(str)(binarySharedObject.functionTypes)
231271

232272
with open(os.path.join(tempTargetDir, "name_manifest.dat"), "wb") as f:
233-
f.write(SerializationContext().serialize(manifest, Dict(str, str)))
273+
f.write(SerializationContext().serialize(manifest, Set(str)))
234274

235275
with open(os.path.join(tempTargetDir, "name_manifest.txt"), "w") as f:
236276
for sourceName in manifest:
@@ -262,7 +302,7 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
262302
return targetDir, hashToUse
263303

264304
def function_pointer_by_name(self, linkName):
265-
moduleHash = self.nameToModuleHash.get(linkName)
305+
moduleHash = self.symbolToLoadedModuleHash.get(linkName)
266306
if moduleHash is None:
267307
raise Exception("Can't find a module for " + linkName)
268308

typed_python/compiler/native_compiler/compiler_cache_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_compiler_cache_handles_changed_types():
359359

360360
# if we try to use 'f', it should work even though we no longer have
361361
# a defniition for 'g2'
362-
assert evaluateExprInFreshProcess(VERSION2, 'x.f(1)', compilerCacheDir) == 1
362+
assert evaluateExprInFreshProcess(VERSION2, 'x.f(1)', compilerCacheDir, printComments=True) == 1
363363
assert len(os.listdir(compilerCacheDir)) == 2
364364

365365
badCt = 0

typed_python/compiler/native_compiler/native_compiler.py

Lines changed: 93 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typed_python import OneOf
16+
1517
import llvmlite.binding as llvm
1618
import typed_python.compiler.native_compiler.native_ast as native_ast
1719
import typed_python.compiler.native_compiler.native_ast_to_llvm as native_ast_to_llvm
@@ -20,36 +22,86 @@
2022
from typed_python.compiler.native_compiler.loaded_module import LoadedModule
2123
from typed_python.compiler.native_compiler.native_function_pointer import NativeFunctionPointer
2224
from typed_python.compiler.native_compiler.binary_shared_object import BinarySharedObject
25+
from typed_python.compiler.native_compiler.typed_call_target import TypedCallTarget
2326

2427

2528
class NativeCompiler:
2629
""""Engine for compiling bundles of native_ast.Function objects into NativeFunctionPointers.
2730
2831
This class is responsible for
32+
2933
* telling clients what named functions have been defined and what their types are.
3034
* compiling functions into a runnable form using llvm
3135
* performing any runtime-based performance optimizations
3236
* maintaining the compiler cache
3337
34-
Note that this class is NOT threadsafe and clients are expected to serialize their
35-
access through Runtime.
38+
Note that this class is NOT threadsafe
3639
"""
3740
def __init__(self, inlineThreshold):
3841
self.compilerCache = None
42+
self.hasEverHadFunctionsAdded = False
43+
3944
self.engine, self.module_pass_manager = create_execution_engine(inlineThreshold)
4045
self.converter = native_ast_to_llvm.Converter()
41-
self.functions_by_name = {}
4246
self.inlineThreshold = inlineThreshold
4347
self.verbose = False
4448
self.optimize = True
4549

50+
# map from linkName: str -> NativeFunctionPointer
51+
# this is only populated if we don't have a compiler cache
52+
self.linkNameToFunctionPtr = {}
53+
54+
# map from linkName: str -> TypedCallTarget.
55+
# this contains every typed call target that we have loaded
56+
# but not the ones that are available in the compiler cache but that
57+
# have not yet been loaded
58+
self._allTypedCallTargets = {}
59+
self._allFunctionsDefined = set()
60+
61+
def isFunctionDefined(self, linkName: str) -> bool:
62+
"""Is this function known to the compiler?"""
63+
if self.compilerCache is None:
64+
return linkName in self.linkNameToFunctionPtr
65+
else:
66+
if linkName in self._allFunctionsDefined:
67+
return True
68+
69+
if not self.compilerCache.hasSymbol(linkName):
70+
return False
71+
72+
# the compiler cache has the symbol but we can't currently be sure
73+
# that we can load it - so load it. If it fails, it won't tell us
74+
# that it can load that symbol a second time.
75+
self._loadFromCache(linkName)
76+
77+
return linkName in self._allFunctionsDefined
78+
79+
def typedCallTargetFor(self, linkName: str) -> OneOf(None, TypedCallTarget):
80+
"""If this function is known, its TypedCallTarget, or None.
81+
82+
If this function was added without a TypedCallTarget (say, because its
83+
a destructor or some other untyped function) then this will be None.
84+
"""
85+
if self.compilerCache is None:
86+
# if we have no compiler cache, then _allTypedCallTargets will
87+
# contain everything we've ever seen
88+
return self._allTypedCallTargets.get(linkName)
89+
else:
90+
if linkName in self._allTypedCallTargets:
91+
return self._allTypedCallTargets[linkName]
92+
93+
if linkName not in self._allFunctionsDefined:
94+
if self.compilerCache.hasSymbol(linkName):
95+
self._loadFromCache(linkName)
96+
97+
return self._allTypedCallTargets.get(linkName)
98+
4699
def initializeCompilerCache(self, compilerCacheDir):
47100
"""Indicate that we should use a compiler cache from disk at 'compilerCacheDir'."""
48-
self.compilerCache = CompilerCache(compilerCacheDir)
101+
if self.hasEverHadFunctionsAdded:
102+
raise Exception("Can't set the compiler cache if we've added functions.")
49103

50-
def markExternal(self, functionNameToType):
51-
"""Provide type signatures for a set of external functions."""
52-
self.converter.markExternal(functionNameToType)
104+
self.compilerCache = CompilerCache(compilerCacheDir)
53105

54106
def mark_converter_verbose(self):
55107
self.converter.verbose = True
@@ -69,51 +121,67 @@ def addFunctions(
69121
70122
Once a function has been added, we can request a NativeFunctionPointer for it.
71123
"""
124+
self.hasEverHadFunctionsAdded = True
125+
72126
if self.compilerCache is None:
127+
self._allTypedCallTargets.update(typedCallTargets)
128+
self._allFunctionsDefined.update(functionDefinitions)
129+
73130
loadedModule = self._buildModule(functionDefinitions)
74131
loadedModule.linkGlobalVariables()
75132
else:
76133
binary = self._buildSharedObject(functionDefinitions)
77-
78134
self.compilerCache.addModule(
79135
binary,
80136
typedCallTargets,
81137
externallyUsed
82138
)
83139

84-
def functionPointerByName(self, linkerName) -> NativeFunctionPointer:
140+
def functionPointerByName(self, linkName) -> NativeFunctionPointer:
85141
"""Find a NativeFunctionPointer for a given link-time name.
86142
87143
Args:
88-
linkerName (str) - the name of the compiled symbol we want.
144+
linkName (str) - the name of the compiled symbol we want.
89145
90146
Returns:
91147
a NativeFunctionPointer or None if the function has never been defined.
92148
"""
93149
if self.compilerCache is not None:
94150
# the compiler cache has every shared object and can load them
95-
return self.compilerCache.function_pointer_by_name(linkerName)
151+
if linkName in self.linkNameToFunctionPtr:
152+
return self.linkNameToFunctionPtr[linkName]
153+
154+
if not self.compilerCache.hasSymbol(linkName):
155+
return None
156+
157+
self.compilerCache.loadForSymbol(linkName)
158+
159+
funcPtr = self.compilerCache.function_pointer_by_name(linkName)
160+
161+
assert funcPtr is not None
162+
163+
self.linkNameToFunctionPtr[linkName] = funcPtr
164+
165+
return funcPtr
96166

97167
# the llvm compiler is just building shared objects, but the
98168
# compiler cache has all the pointers.
99-
return self.functions_by_name.get(linkerName)
169+
return self.linkNameToFunctionPtr.get(linkName)
100170

101-
def loadFromCache(self, linkName):
171+
def _loadFromCache(self, linkName):
102172
"""Attempt to load a cached copy of 'linkName' and all reachable code.
103173
104-
If it isn't defined, or has already been defined, return None. If we're loading it
105-
for the first time, return a pair
106-
107-
(typedCallTargets, nativeTypes)
174+
The compilerCache must exist and agree that this function exists.
175+
"""
176+
assert self.compilerCache and self.compilerCache.hasSymbol(linkName)
108177

109-
where typedCallTargets is a map from linkName to TypedCallTarget, and nativeTypes is
110-
a map from linkName to native_ast.Type.Function giving the native implementation type.
178+
callTargetsAndTypes = self.compilerCache.loadForSymbol(linkName)
111179

112-
WARNING: this will return None if you already called 'functionPointerByName' on it
113-
"""
114-
if self.compilerCache:
115-
if self.compilerCache.hasSymbol(linkName):
116-
return self.compilerCache.loadForSymbol(linkName)
180+
if callTargetsAndTypes is not None:
181+
newTypedCallTargets, newNativeFunctionTypes = callTargetsAndTypes
182+
self.converter.markExternal(newNativeFunctionTypes)
183+
self._allTypedCallTargets.update(newTypedCallTargets)
184+
self._allFunctionsDefined.update(newNativeFunctionTypes)
117185

118186
def _buildSharedObject(self, functions):
119187
"""Add native definitions and return a BinarySharedObject representing the compiled code."""
@@ -184,7 +252,7 @@ def _buildModule(self, functions):
184252
native_function_pointers[fname] = NativeFunctionPointer(
185253
fname, func_ptr, input_types, output_type
186254
)
187-
self.functions_by_name[fname] = native_function_pointers[fname]
255+
self.linkNameToFunctionPtr[fname] = native_function_pointers[fname]
188256

189257
native_function_pointers[module.GET_GLOBAL_VARIABLES_NAME] = (
190258
NativeFunctionPointer(

0 commit comments

Comments
 (0)