Skip to content

Commit 2a14619

Browse files
author
Braxton Mckee
committed
Ensure that we don't serialize module names on FunctionType when we're suppressing line info.
1 parent 38a5d79 commit 2a14619

File tree

6 files changed

+123
-3
lines changed

6 files changed

+123
-3
lines changed

typed_python/NullSerializationContext.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,7 @@ class NullSerializationContext : public SerializationContext {
4040
virtual bool isCompressionEnabled() const {
4141
return false;
4242
}
43+
virtual bool isLineInfoSuppressed() const {
44+
return false;
45+
}
4346
};

typed_python/PythonSerializationContext.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ void PythonSerializationContext::setFlags() {
4646
}
4747

4848
mSerializeHashSequence = ((PyObject*)serializeHashSequence) == Py_True;
49+
50+
PyObjectStealer encodeLineInformationForCode(PyObject_GetAttrString(mContextObj, "encodeLineInformationForCode"));
51+
52+
if (!encodeLineInformationForCode) {
53+
throw PythonExceptionSet();
54+
}
55+
56+
mSuppressLineInfo = ((PyObject*)encodeLineInformationForCode) == Py_False;
4957
}
5058

5159
std::string PythonSerializationContext::getNameForPyObj(PyObject* o) const {

typed_python/PythonSerializationContext.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ class PythonSerializationContext : public SerializationContext {
9393
bool isCompressionEnabled() const {
9494
return mCompressionEnabled;
9595
}
96+
97+
bool isLineInfoSuppressed() const {
98+
return mSuppressLineInfo;
99+
}
96100

97101
// should we serialize an integer in the order of the
98102
// hash sequence rather than the hash itself?
@@ -185,6 +189,8 @@ class PythonSerializationContext : public SerializationContext {
185189

186190
bool mCompressionEnabled;
187191

192+
bool mSuppressLineInfo;
193+
188194
bool mInternalizeTypeGroups;
189195

190196
bool mSerializeHashSequence;

typed_python/PythonSerializationContext_serialization.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,14 @@ void PythonSerializationContext::serializeNativeTypeInner(
824824
serializeNativeType(ftype->getClosureType(), b, 1);
825825
b.writeStringObject(2, ftype->name());
826826
b.writeStringObject(3, ftype->qualname());
827-
b.writeStringObject(4, ftype->moduleName());
827+
b.writeStringObject(
828+
4,
829+
// If we're suppressing line info, then we also don't want to write
830+
// module names. Otherwise, it's impossible to get stable SHA hashes of
831+
// code that gets relocated across file boundaries.
832+
b.getContext().isLineInfoSuppressed() ? std::string("") : ftype->moduleName()
833+
);
834+
828835
b.writeUnsignedVarintObject(5, ftype->isEntrypoint() ? 1 : 0);
829836
b.writeUnsignedVarintObject(6, ftype->isNocompile() ? 1 : 0);
830837

typed_python/SerializationContext.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ class SerializationContext {
3535
virtual Type* deserializeNativeType(DeserializationBuffer& b, size_t wireType) const = 0;
3636

3737
virtual bool isCompressionEnabled() const = 0;
38+
virtual bool isLineInfoSuppressed() const = 0;
3839
};

typed_python/types_serialization_test.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@
4545
Dict, Set, SerializationContext, EmbeddedMessage,
4646
serializeStream, deserializeStream, decodeSerializedObject,
4747
Forward, Final, Function, Entrypoint, TypeFunction, PointerTo,
48-
SubclassOf
48+
SubclassOf, NotCompiled
4949
)
5050

5151
from typed_python._types import (
5252
refcount, isRecursive, identityHash, buildPyFunctionObject,
53-
setFunctionClosure, typesAreEquivalent, recursiveTypeGroupDeepRepr
53+
setFunctionClosure, typesAreEquivalent, recursiveTypeGroupDeepRepr,
5454
)
5555

5656
module_level_testfun = dummy_test_module.testfunction
@@ -2735,6 +2735,61 @@ def test_serialization_independent_of_whether_function_is_hashed(self):
27352735

27362736
assert s1 == s2
27372737

2738+
def test_serialization_has_no_filename_reference(self):
2739+
def makeF(modulename):
2740+
with tempfile.TemporaryDirectory() as tempdir:
2741+
path = os.path.join(tempdir, modulename + ".py")
2742+
2743+
CONTENTS = (
2744+
"def f(x):\n"
2745+
" return x\n"
2746+
)
2747+
2748+
with open(path, "w") as f:
2749+
f.write(CONTENTS)
2750+
2751+
globals = {'__file__': path}
2752+
2753+
exec(
2754+
compile(CONTENTS, path, "exec"),
2755+
globals
2756+
)
2757+
2758+
s = SerializationContext()
2759+
return s.serialize(globals['f'])
2760+
2761+
f1 = SerializationContext().deserialize(callFunctionInFreshProcess(makeF, ('asdf',)))
2762+
f2 = SerializationContext().deserialize(callFunctionInFreshProcess(makeF, ('asdf2',)))
2763+
2764+
def checkSame(f1, f2):
2765+
s = SerializationContext().withoutCompression()
2766+
s2 = SerializationContext().withoutCompression().withoutLineInfoEncoded().withSerializeHashSequence()
2767+
2768+
assert s.serialize(f1) != s.serialize(f2)
2769+
2770+
if s2.serialize(f1) != s2.serialize(f2):
2771+
decoded1 = decodeSerializedObject(s2.serialize(f1))
2772+
decoded2 = decodeSerializedObject(s2.serialize(f2))
2773+
2774+
decoded1Print = pprint.PrettyPrinter(indent=2).pformat(decoded1).split("\n")
2775+
decoded2Print = pprint.PrettyPrinter(indent=2).pformat(decoded2).split("\n")
2776+
2777+
for i in range(len(decoded1Print)):
2778+
if decoded1Print[i] != decoded2Print[i]:
2779+
for j in range(max(0, i-5), i):
2780+
print(decoded1Print[j])
2781+
print("******************************** DIFFERENCE *******************")
2782+
print(decoded1Print[i])
2783+
print(decoded2Print[i])
2784+
print("***************************************************************")
2785+
break
2786+
2787+
assert s2.serialize(f1) == s2.serialize(f2)
2788+
2789+
checkSame(f1, f2)
2790+
checkSame(Entrypoint(f1), Entrypoint(f2))
2791+
checkSame(NotCompiled(f1), NotCompiled(f2))
2792+
27382793
def test_serialize_anonymous_class_with_defaults_and_nonempty(self):
27392794
class C1(Class):
27402795
x1 = Member(int, default_value=10, nonempty=True)
@@ -2821,3 +2876,43 @@ def test_serialization_context_names_for_pmap_functions(self):
28212876
from typed_python.lib.pmap import ensureThreads
28222877
sc = SerializationContext()
28232878
assert sc.nameForObject(type(ensureThreads)) is not None
2879+
2880+
def test_pmap_of_notcompiled_serialized_externally(self):
2881+
def makeC():
2882+
with tempfile.TemporaryDirectory() as tempdir:
2883+
path = os.path.join(tempdir, "asdf.py")
2884+
2885+
CONTENTS = (
2886+
"from typed_python import NotCompiled\n"
2887+
"@NotCompiled\n"
2888+
"def f(x) -> str:\n"
2889+
" return str(x)\n"
2890+
)
2891+
2892+
with open(path, "w") as f:
2893+
f.write(CONTENTS)
2894+
2895+
globals = {'__file__': path}
2896+
2897+
exec(
2898+
compile(CONTENTS, path, "exec"),
2899+
globals
2900+
)
2901+
2902+
s = SerializationContext()
2903+
return s.serialize(globals['f'])
2904+
2905+
serializedF = callFunctionInFreshProcess(makeC, ())
2906+
2907+
f = SerializationContext().deserialize(serializedF)
2908+
2909+
assert f(2) == "2"
2910+
2911+
from typed_python.lib.pmap import pmap
2912+
2913+
args = ListOf(int)(range(1000000))
2914+
2915+
t0 = time.time()
2916+
while time.time() - t0 < 10.0:
2917+
print("DO!")
2918+
pmap(args, f, str, minGranularity=10000)

0 commit comments

Comments
 (0)