1919from typed_python .compiler .native_compiler .binary_shared_object import BinarySharedObject
2020
2121from typed_python .SerializationContext import SerializationContext
22- from typed_python import Dict , ListOf
22+ from typed_python import Dict , ListOf , Set
2323
2424
2525def ensureDirExists (cacheDir ):
@@ -53,32 +53,65 @@ def __init__(self, cacheDir):
5353 ensureDirExists (cacheDir )
5454
5555 self .loadedModules = Dict (str , LoadedModule )()
56- self .nameToModuleHash = Dict (str , str )()
5756
57+ # for each symbol, the first module we loaded that has that symbol
58+ self .symbolToLoadedModuleHash = Dict (str , str )()
59+
60+ # for each module that we loaded or might load, the contents
61+ self .symbolToModuleHashes = Dict (str , Set (str ))()
62+ self .moduleHashToSymbols = Dict (str , Set (str ))()
63+
64+ # modules we might be able to load
5865 self .modulesMarkedValid = set ()
66+
67+ # modules we definitely can't load
5968 self .modulesMarkedInvalid = set ()
6069
6170 for moduleHash in os .listdir (self .cacheDir ):
6271 if len (moduleHash ) == 40 :
6372 self .loadNameManifestFromStoredModuleByHash (moduleHash )
6473
65- def hasSymbol (self , linkName ):
66- return linkName in self . nameToModuleHash
74+ def hasSymbol (self , symbol ):
75+ """Do we have this symbol defined somewhere?
6776
68- def markModuleHashInvalid (self , hashstr ):
69- with open (os .path .join (self .cacheDir , hashstr , "marked_invalid" ), "w" ):
70- pass
77+ Note that this can change: if we attempt to load a symbol and fail,
78+ then it may no longer be defined anywhere. To really know if you have
79+ a symbol, you have to load it.
80+ """
81+ return symbol in self .symbolToModuleHashes
7182
72- def loadForSymbol (self , linkName ):
73- moduleHash = self .nameToModuleHash [linkName ]
83+ def markModuleHashInvalid (self , moduleHash ):
84+ """Mark this module unloadable on disk and remove its symbols."""
85+ with open (os .path .join (self .cacheDir , moduleHash , "marked_invalid" ), "w" ):
86+ pass
7487
75- nameToTypedCallTarget = {}
76- nameToNativeFunctionType = {}
88+ # remove any symbols that we can't see anymore
89+ for symbol in self .moduleHashToSymbols .pop (moduleHash , Set (str )()):
90+ hashes = self .symbolToModuleHashes [symbol ]
91+ hashes .discard (moduleHash )
92+ if not hashes :
93+ del self .symbolToModuleHashes [symbol ]
7794
78- if not self .loadModuleByHash (moduleHash , nameToTypedCallTarget , nameToNativeFunctionType ):
95+ def loadForSymbol (self , symbol ):
96+ # check if this symbol is already loaded
97+ if symbol in self .symbolToLoadedModuleHash :
7998 return None
8099
81- return nameToTypedCallTarget , nameToNativeFunctionType
100+ while symbol in self .symbolToModuleHashes :
101+ moduleHash = list (self .symbolToModuleHashes [symbol ])[0 ]
102+
103+ nameToTypedCallTarget = {}
104+ nameToNativeFunctionType = {}
105+
106+ if self .loadModuleByHash (moduleHash , nameToTypedCallTarget , nameToNativeFunctionType ):
107+ return nameToTypedCallTarget , nameToNativeFunctionType
108+ else :
109+ assert (
110+ # either we can't load this symbol at all anymore
111+ symbol not in self .symbolToModuleHashes
112+ # or confirm we can't try to load this again
113+ or moduleHash not in self .symbolToModuleHashes [symbol ]
114+ )
82115
83116 def loadModuleByHash (self , moduleHash , nameToTypedCallTarget , nameToNativeFunctionType ):
84117 """Load a module by name.
@@ -134,6 +167,10 @@ def loadModuleByHash(self, moduleHash, nameToTypedCallTarget, nameToNativeFuncti
134167 nameToTypedCallTarget .update (callTargets )
135168 nameToNativeFunctionType .update (functionNameToNativeType )
136169
170+ for symbol in functionNameToNativeType :
171+ if symbol not in self .symbolToLoadedModuleHash :
172+ self .symbolToLoadedModuleHash [symbol ] = moduleHash
173+
137174 return True
138175
139176 def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies ):
@@ -149,16 +186,19 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies)
149186 dependentHashes = set ()
150187
151188 for name in linkDependencies :
152- dependentHashes .add (self .nameToModuleHash [name ])
189+ dependentHashes .add (self .symbolToLoadedModuleHash [name ])
153190
154191 path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes )
155192
156193 self .loadedModules [hashToUse ] = (
157194 binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
158195 )
159196
160- for n in binarySharedObject .definedSymbols :
161- self .nameToModuleHash [n ] = hashToUse
197+ for symbol in binarySharedObject .definedSymbols :
198+ if symbol not in self .symbolToLoadedModuleHash :
199+ self .symbolToLoadedModuleHash [symbol ] = hashToUse
200+ self .symbolToModuleHashes .setdefault (symbol ).add (hashToUse )
201+ self .moduleHashToSymbols [hashToUse ] = Set (str )(binarySharedObject .definedSymbols )
162202
163203 def loadNameManifestFromStoredModuleByHash (self , moduleHash ):
164204 if moduleHash in self .modulesMarkedValid :
@@ -185,9 +225,11 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash):
185225 return False
186226
187227 with open (os .path .join (targetDir , "name_manifest.dat" ), "rb" ) as f :
188- self .nameToModuleHash .update (
189- SerializationContext ().deserialize (f .read (), Dict (str , str ))
190- )
228+ manifest = SerializationContext ().deserialize (f .read (), Set (str ))
229+
230+ for symbolName in manifest :
231+ self .symbolToModuleHashes .setdefault (symbolName ).add (moduleHash )
232+ self .moduleHashToSymbols [moduleHash ] = manifest
191233
192234 self .modulesMarkedValid .add (moduleHash )
193235
@@ -225,12 +267,10 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
225267
226268 # write the manifest. Every TP process using the cache will have to
227269 # load the manifest every time, so we try to use compiled code to load it
228- manifest = Dict (str , str )()
229- for n in binarySharedObject .functionTypes :
230- manifest [n ] = hashToUse
270+ manifest = Set (str )(binarySharedObject .functionTypes )
231271
232272 with open (os .path .join (tempTargetDir , "name_manifest.dat" ), "wb" ) as f :
233- f .write (SerializationContext ().serialize (manifest , Dict ( str , str )))
273+ f .write (SerializationContext ().serialize (manifest , Set ( str )))
234274
235275 with open (os .path .join (tempTargetDir , "name_manifest.txt" ), "w" ) as f :
236276 for sourceName in manifest :
@@ -262,7 +302,7 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
262302 return targetDir , hashToUse
263303
264304 def function_pointer_by_name (self , linkName ):
265- moduleHash = self .nameToModuleHash .get (linkName )
305+ moduleHash = self .symbolToLoadedModuleHash .get (linkName )
266306 if moduleHash is None :
267307 raise Exception ("Can't find a module for " + linkName )
268308
0 commit comments