Skip to content

Commit b74d274

Browse files
committed
Allow all Function objects to implicitly depend on a '__module_hash__' member of their globals.
In some downstream use cases, we end up with Function objects that were created with globals dicts that are not backed by actual modules. In that case, if those functions reference mutable global variables, we have no way of differentiating two functions with identical behavior but different globals. Normally we don't need to differentiate them: functions access their globals through their reference to the module's globals dict, and only one such object can exist in memory under a given name at once. But if a function's globals dict is not backed by a named module, we need another way of differentiating the dicts. To support this, we introduce a __module_hash__ variable: all TP function objects (Entrypoint, NotCompiled, Function, and Class methods) have an implicit dependency on this global variable. Downstream code can set this value to prevent two such functions from being treated as the same.
1 parent aa88ad8 commit b74d274

File tree

5 files changed

+117
-2
lines changed

5 files changed

+117
-2
lines changed

typed_python/FunctionType.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,9 @@ class Function : public Type {
10761076
uint8_t* bytes;
10771077
Py_ssize_t bytecount;
10781078

1079+
static PyObject* moduleHashName = PyUnicode_FromString("__module_hash__");
1080+
outSequences.push_back(std::vector<PyObject*>({moduleHashName}));
1081+
10791082
PyBytes_AsStringAndSize(((PyCodeObject*)code)->co_code, (char**)&bytes, &bytecount);
10801083

10811084
long opcodeCount = bytecount / 2;
@@ -1172,6 +1175,8 @@ class Function : public Type {
11721175
extractGlobalAccessesFromCode((PyCodeObject*)o, outAccesses);
11731176
}
11741177
});
1178+
1179+
outAccesses.insert("__module_hash__");
11751180
}
11761181

11771182
static void extractNamesFromCode(PyCodeObject* code, std::set<PyObject*>& outNames) {
@@ -1183,6 +1188,9 @@ class Function : public Type {
11831188
extractNamesFromCode((PyCodeObject*)o, outNames);
11841189
}
11851190
});
1191+
1192+
static PyObject* moduleHashName = PyUnicode_FromString("__module_hash__");
1193+
outNames.insert(moduleHashName);
11861194
}
11871195

11881196
void setGlobals(PyObject* globals) {

typed_python/SerializationContext.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ class SerializationContext(Class, Final):
130130
nameForObjectOverride = Member(OneOf(None, object))
131131
objectFromNameOverride = Member(OneOf(None, object))
132132

133+
# these are used to control how we serialize specific types
134+
representationForOverride = Member(OneOf(None, object))
135+
setInstanceStateOverride = Member(OneOf(None, object))
136+
133137
def __init__(
134138
self,
135139
nameToObjectOverride=None,
@@ -548,6 +552,11 @@ def representationFor(self, inst):
548552
Returns:
549553
a representation tuple, or None
550554
'''
555+
if self.representationForOverride is not None:
556+
representation = self.representationForOverride(inst)
557+
if representation is not None:
558+
return representation
559+
551560
if type(inst) in (tuple, list, dict, set, str, int, bool, float):
552561
return None
553562

@@ -682,7 +691,7 @@ def representationFor(self, inst):
682691
else:
683692
globalsToUse = {}
684693

685-
all_names = set(['__builtins__'])
694+
all_names = set(['__builtins__', '__module_hash__'])
686695

687696
def walkCodeObject(code):
688697
all_names.update(code.co_names)
@@ -736,6 +745,12 @@ def walkCodeObject(code):
736745
def setInstanceStateFromRepresentation(
737746
self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None
738747
):
748+
if self.setInstanceStateOverride is not None:
749+
if self.setInstanceStateOverride(
750+
instance, representation, itemIt, kvPairIt, setStateFun
751+
):
752+
return
753+
739754
if representation is reconstructTypeFunctionType:
740755
return
741756

typed_python/SpecialModuleNames.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ inline bool isSpecialIgnorableName(const std::string& name) {
9191
"__rpow__", "__rrshift__", "__rshift__", "__rsub__",
9292
"__rtruediv__", "__rxor__", "__setattr__", "__setitem__",
9393
"__str__", "__sub__", "__truediv__", "__xor__",
94+
// this is not a real python magic method, but we want to ensure that
95+
// if this is present in a function's globals, we use it to make the
96+
// function unique.
97+
"__module_hash__"
9498
});
9599

96100
return (

typed_python/internals.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ def realizedGlobals(self):
484484
if name.split(".")[0] in globalNames:
485485
res[name] = self.functionGlobals[name]
486486

487+
if '__module_hash__' in self.functionGlobals:
488+
res['__module_hash__'] = self.functionGlobals['__module_hash__']
489+
487490
for varname, cell in self.funcGlobalsInCells.items():
488491
res[varname] = cell.cell_contents
489492

typed_python/type_identity_test.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
import threading
1616
import pytest
17+
import tempfile
18+
import os
1719

1820
import typed_python
19-
from typed_python.test_util import evaluateExprInFreshProcess
21+
from typed_python.test_util import evaluateExprInFreshProcess, callFunctionInFreshProcess
2022
from typed_python import (
2123
UInt64, UInt32,
2224
ListOf, TupleOf, Tuple, NamedTuple, Dict, OneOf, Forward, identityHash,
@@ -699,3 +701,86 @@ def f(self):
699701
return 0
700702

701703
print(typeWalkRecord(N))
704+
705+
706+
def test_module_hash_magic_value():
707+
with tempfile.TemporaryDirectory() as tempDir:
708+
709+
def makeFun(mh):
710+
"""Produce a dummy function with __module_hash__ of 'mh'
711+
712+
Note that we need to produce this function with 'backing code' so that
713+
the AST system can find it and serialize it.
714+
"""
715+
globalsDict = {}
716+
fname = os.path.join(tempDir, "code_" + mh + ".py")
717+
718+
pyCode = (
719+
f"from typed_python import Function\n"
720+
f"__module_hash__ = '{mh}'\n"
721+
f"@Function\n"
722+
f"def f(x):\n"
723+
f" return x\n"
724+
)
725+
726+
with open(fname, "w") as f:
727+
f.write(pyCode)
728+
729+
exec(compile(pyCode, fname, "exec"), globalsDict)
730+
return globalsDict['f']
731+
732+
def makeFunIH(mh):
733+
return identityHash(type(makeFun(mh)))
734+
735+
# check that the identity hash depends on the module hash
736+
assert identityHash(type(makeFun('A'))) == identityHash(type(makeFun('A')))
737+
assert identityHash(type(makeFun('A'))) != identityHash(type(makeFun('B')))
738+
739+
# functions should have this in their globals
740+
f = makeFun('A')
741+
assert '__module_hash__' in f.overloads[0].functionGlobals
742+
assert '__module_hash__' in f.overloads[0].realizedGlobals
743+
744+
# even if we execute this in another process, we should get the same function back
745+
# and it should have a __module_hash__ in its globals. We have to be careful about
746+
# this because we need the serializer to understand that __module_hash__ is special
747+
# and that the function implicitly references it.
748+
fFromOtherProcess = callFunctionInFreshProcess(makeFun, ('C',))
749+
assert fFromOtherProcess.overloads[0].functionGlobals['__module_hash__'] == 'C'
750+
assert fFromOtherProcess.overloads[0].realizedGlobals['__module_hash__'] == 'C'
751+
752+
# check that the identity hash we loaded is the same one we would get from a
753+
# subprocess reading it.
754+
assert (
755+
identityHash(type(fFromOtherProcess))
756+
== callFunctionInFreshProcess(makeFunIH, ('C',))
757+
)
758+
759+
760+
def test_module_hash_magic_value_on_untyped_function_preserved_by_serialization():
761+
with tempfile.TemporaryDirectory() as tempDir:
762+
763+
def makeFun(mh):
764+
"""Produce a dummy function with __module_hash__ of 'mh'
765+
766+
Note that we need to produce this function with 'backing code' so that
767+
the AST system can find it and serialize it.
768+
"""
769+
globalsDict = {}
770+
fname = os.path.join(tempDir, "code_" + mh + ".py")
771+
772+
pyCode = (
773+
f"from typed_python import Entrypoint\n"
774+
f"__module_hash__ = '{mh}'\n"
775+
f"def f(x):\n"
776+
f" return x\n"
777+
)
778+
779+
with open(fname, "w") as f:
780+
f.write(pyCode)
781+
782+
exec(compile(pyCode, fname, "exec"), globalsDict)
783+
return globalsDict['f']
784+
785+
fFromOtherProcess = callFunctionInFreshProcess(makeFun, ('C',))
786+
assert fFromOtherProcess.__globals__['__module_hash__'] == 'C'

0 commit comments

Comments
 (0)