Skip to content

Commit a6bf326

Browse files
committed
Ensure that Held Class types can reference themselves.
This fixes a nasty serialization bug where mutually recursive type groups can hold references to the same HeldClass in two places (once as a pyobj and once as a type). This was causing serialization of held classes that refer to themselves in their own functions to break. Really, we should ensure that TypeOrPyObj has a single way of referring to any given object - there shouldn't be a distinction in a MutuallyRecursiveTypeGroup between the Type and the PyObject representation of it.
1 parent efff26c commit a6bf326

File tree

6 files changed

+87
-12
lines changed

6 files changed

+87
-12
lines changed

typed_python/MutuallyRecursiveTypeGroup.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,26 @@ void visitCompilerVisibleTypesAndPyobjects(
181181
};
182182

183183
if (obj.type()) {
184+
Type* objType = obj.type();
185+
184186
hashVisit(ShaHash(1));
185187

186-
obj.type()->visitReferencedTypes(visit);
187-
obj.type()->visitCompilerVisiblePythonObjects(visit);
188-
obj.type()->visitCompilerVisibleInstances([&](Instance i) {
188+
objType->visitReferencedTypes(visit);
189+
190+
// ensure that held and non-held versions of Class are
191+
// always visible to each other.
192+
if (objType->isHeldClass()) {
193+
Type* t = ((HeldClass*)objType)->getClassType();
194+
visit(t);
195+
}
196+
197+
if (objType->isClass()) {
198+
Type* t = ((Class*)objType)->getHeldClass();
199+
visit(t);
200+
}
201+
202+
objType->visitCompilerVisiblePythonObjects(visit);
203+
objType->visitCompilerVisibleInstances([&](Instance i) {
189204
visit(i.type());
190205

191206
if (i.type()->getTypeCategory() == Type::TypeCategory::catPythonObjectOfType) {

typed_python/PyInstance.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ void PyInstance::copyConstructFromPythonInstanceConcrete(Type* eltType, instance
202202
aNewName = typeName;
203203
}
204204

205-
std::string verb;
206205
if (level == ConversionLevel::Signature) {
207206
throw std::logic_error("Object of type " + std::string(pyRepresentation->ob_type->tp_name) + " is not " + typeName);
208207
}

typed_python/PythonSerializationContext_serialization.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ void PythonSerializationContext::serializePythonObjectNamedOrAsObj(PyObject* o,
284284
void PythonSerializationContext::serializePythonObjectRepresentation(PyObject* representation, SerializationBuffer& b, size_t fieldNumber) const {
285285
PyEnsureGilAcquired acquireTheGil;
286286

287-
if (!PyTuple_Check(representation) || PyTuple_Size(representation) < 2
287+
if (!PyTuple_Check(representation) || PyTuple_Size(representation) < 2
288288
|| PyTuple_Size(representation) > 6) {
289289
throw std::runtime_error("representationFor should return None or a tuple with 2 to 6 things");
290290
}
@@ -492,7 +492,7 @@ void PythonSerializationContext::serializeMutuallyRecursiveTypeGroup(MutuallyRec
492492
}
493493

494494
if (representation != Py_None) {
495-
if (!PyTuple_Check(representation) || PyTuple_Size(representation) < 2
495+
if (!PyTuple_Check(representation) || PyTuple_Size(representation) < 2
496496
|| PyTuple_Size(representation) > 6) {
497497
throw std::runtime_error("representationFor should return None or a tuple with 2 to 6 things");
498498
}
@@ -630,13 +630,15 @@ void PythonSerializationContext::serializeMutuallyRecursiveTypeGroup(MutuallyRec
630630
// then HeldClass objects, ordered with all bases before
631631
// their children (so that when we deserialize, any base classes
632632
// are no longer forwards)
633-
std::map<HeldClass*, int> heldClasses;
633+
std::map<HeldClass*, std::set<int> > heldClasses;
634634
for (auto& indexAndObj: group->getIndexToObject()) {
635635
if (
636636
indexAndObj.second.typeOrPyobjAsType() &&
637637
indexAndObj.second.typeOrPyobjAsType()->isHeldClass()
638638
) {
639-
heldClasses[(HeldClass*)indexAndObj.second.typeOrPyobjAsType()] = indexAndObj.first;
639+
heldClasses[(HeldClass*)indexAndObj.second.typeOrPyobjAsType()].insert(
640+
indexAndObj.first
641+
);
640642
}
641643
}
642644

@@ -657,7 +659,9 @@ void PythonSerializationContext::serializeMutuallyRecursiveTypeGroup(MutuallyRec
657659
}
658660

659661
// now we can write this object.
660-
writeObjectBody(heldClasses[cls], cls);
662+
for (auto index: heldClasses[cls]) {
663+
writeObjectBody(index, cls);
664+
}
661665
};
662666

663667
for (auto classAndIndex: heldClasses) {

typed_python/_types.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,9 +2770,9 @@ PyObject *recursiveTypeGroupHash(PyObject* nullValue, PyObject* args) {
27702770
});
27712771
}
27722772

2773-
PyObject *recursiveTypeGroupDeepRepr(PyObject* nullValue, PyObject* args) {
2773+
PyObject *recursiveTypeGroupReprDeepFlag(PyObject* nullValue, PyObject* args, bool deep) {
27742774
if (PyTuple_Size(args) != 1) {
2775-
PyErr_SetString(PyExc_TypeError, "recursiveTypeGroupDeepRepr takes 1 positional argument");
2775+
PyErr_SetString(PyExc_TypeError, "recursiveTypeGroupRepr takes 1 positional argument");
27762776
return NULL;
27772777
}
27782778
PyObjectHolder a1(PyTuple_GetItem(args, 0));
@@ -2795,10 +2795,20 @@ PyObject *recursiveTypeGroupDeepRepr(PyObject* nullValue, PyObject* args) {
27952795
return incref(Py_None);
27962796
}
27972797

2798-
return PyUnicode_FromString(group->repr(true).c_str());
2798+
return PyUnicode_FromString(group->repr(deep).c_str());
27992799
});
28002800
}
28012801

2802+
PyObject *recursiveTypeGroupRepr(PyObject* nullValue, PyObject* args) {
2803+
return recursiveTypeGroupReprDeepFlag(nullValue, args, false);
2804+
}
2805+
2806+
2807+
PyObject *recursiveTypeGroupDeepRepr(PyObject* nullValue, PyObject* args) {
2808+
return recursiveTypeGroupReprDeepFlag(nullValue, args, true);
2809+
}
2810+
2811+
28022812
PyObject *recursiveTypeGroup(PyObject* nullValue, PyObject* args) {
28032813
if (PyTuple_Size(args) != 1) {
28042814
PyErr_SetString(PyExc_TypeError, "recursiveTypeGroup takes 1 positional argument");
@@ -3247,6 +3257,7 @@ static PyMethodDef module_methods[] = {
32473257
{"Forward", (PyCFunction)MakeForward, METH_VARARGS, NULL},
32483258
{"allForwardTypesResolved", (PyCFunction)allForwardTypesResolved, METH_VARARGS, NULL},
32493259
{"recursiveTypeGroup", (PyCFunction)recursiveTypeGroup, METH_VARARGS, NULL},
3260+
{"recursiveTypeGroupRepr", (PyCFunction)recursiveTypeGroupRepr, METH_VARARGS, NULL},
32503261
{"recursiveTypeGroupDeepRepr", (PyCFunction)recursiveTypeGroupDeepRepr, METH_VARARGS, NULL},
32513262
{"recursiveTypeGroupHash", (PyCFunction)recursiveTypeGroupHash, METH_VARARGS, NULL},
32523263
{"typesAndObjectsVisibleToCompilerFrom", (PyCFunction)typesAndObjectsVisibleToCompilerFrom, METH_VARARGS, NULL},

typed_python/type_identity_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ def returnSerializedValue(filesToWrite, expression, printComments=False):
6363
)
6464

6565

66+
def test_class_and_held_class_in_group():
67+
class C(Class):
68+
pass
69+
70+
H = C.HeldClass
71+
72+
assert H in recursiveTypeGroup(H)
73+
assert C in recursiveTypeGroup(H)
74+
75+
assert H in recursiveTypeGroup(C)
76+
assert C in recursiveTypeGroup(C)
77+
78+
6679
def test_identity_of_register_types():
6780
assert isinstance(identityHash(UInt64), bytes)
6881
assert len(identityHash(UInt64)) == 20

typed_python/types_serialization_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,6 +2691,39 @@ def makeC():
26912691
assert C.anF() is C
26922692
assert C.anF.overloads[0].methodOf.Class is C
26932693

2694+
def test_held_class_serialized_externally(self):
2695+
def makeC():
2696+
with tempfile.TemporaryDirectory() as tempdir:
2697+
path = os.path.join(tempdir, "asdf.py")
2698+
2699+
CONTENTS = (
2700+
"from typed_python import Entrypoint, ListOf, Class, Held, Final\n"
2701+
"@Held\n"
2702+
"class C(Class, Final):\n"
2703+
" @staticmethod\n"
2704+
" def make():\n"
2705+
" return C()\n"
2706+
)
2707+
2708+
with open(path, "w") as f:
2709+
f.write(CONTENTS)
2710+
2711+
globals = {'__file__': path}
2712+
2713+
exec(
2714+
compile(CONTENTS, path, "exec"),
2715+
globals
2716+
)
2717+
2718+
s = SerializationContext()
2719+
return s.serialize(globals['C'])
2720+
2721+
serializedC = callFunctionInFreshProcess(makeC, ())
2722+
2723+
C = SerializationContext().deserialize(serializedC)
2724+
2725+
assert type(C.make()) is C
2726+
26942727
def test_serialization_independent_of_whether_function_is_hashed(self):
26952728
s = SerializationContext().withoutLineInfoEncoded().withoutCompression()
26962729

0 commit comments

Comments
 (0)