Skip to content

Commit 55854ea

Browse files
committed
Ensure that we walk into closures of function instances bound inside of types.
This demonstrates that we need to rethink how we handle TP Instances in the MRTG since its clearly possible for TP instances to be directly reachable as constants in this kind of code, and we're not including that in the compiler hash correctly right now.
1 parent bef26a6 commit 55854ea

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

typed_python/CompilerVisibleObjectVisitor.hpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,39 @@ class CompilerVisibleObjectVisitor {
523523
);
524524
}
525525

526+
template<class visitor_type>
527+
static void walkInstance(
528+
Type* objType,
529+
instance_ptr instance,
530+
const visitor_type& visitor
531+
) {
532+
if (objType->isComposite()) {
533+
CompositeType* compType = (CompositeType*)objType;
534+
for (long k = 0; k < compType->getTypes().size(); k++) {
535+
walkInstance(
536+
compType->getTypes()[k],
537+
compType->eltPtr(instance, k),
538+
visitor
539+
);
540+
}
541+
return;
542+
}
543+
544+
if (objType->isPyCell()) {
545+
static PyCellType* pct = PyCellType::Make();
546+
547+
PyObject* o = pct->getPyObj(instance);
548+
549+
if (!PyCell_Check(o) || !PyCell_Get(o)) {
550+
return;
551+
}
552+
553+
visitor.visitTopo(PyCell_Get(o));
554+
return;
555+
}
556+
}
557+
558+
526559
template<class visitor_type>
527560
static void walk(
528561
TypeOrPyobj obj,
@@ -576,6 +609,19 @@ class CompilerVisibleObjectVisitor {
576609
if (argType) {
577610
visitor.visitHash(ShaHash(2));
578611
visitor.visitTopo(argType);
612+
613+
if (argType->isFunction()) {
614+
// visit the function's closure
615+
Function* funcT = (Function*)argType;
616+
instance_ptr dataPtr = ((PyInstance*)obj.pyobj())->dataPtr();
617+
618+
walkInstance(
619+
funcT->getClosureType(),
620+
dataPtr,
621+
visitor
622+
);
623+
}
624+
579625
return;
580626
}
581627

typed_python/FunctionType.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,10 @@ class Function : public Type {
876876
);
877877
}
878878

879-
visitor.visitNamedTopo(nameAndGlobal.first, nameAndGlobal.second);
879+
visitor.visitNamedTopo(
880+
nameAndGlobal.first,
881+
nameAndGlobal.second
882+
);
880883
}
881884

882885
if (mReturnType) {

typed_python/Type.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ class Type {
197197
return m_typeCategory == catPythonObjectOfType;
198198
}
199199

200+
bool isPyCell() const {
201+
return m_typeCategory == catPyCell;
202+
}
203+
200204
bool isClass() const {
201205
return m_typeCategory == catClass;
202206
}

typed_python/type_identity_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,3 +784,51 @@ def makeFun(mh):
784784

785785
fFromOtherProcess = callFunctionInFreshProcess(makeFun, ('C',))
786786
assert fFromOtherProcess.__globals__['__module_hash__'] == 'C'
787+
788+
789+
@TypeFunction
790+
def RefsAValueInEntrypointedStaticmethod(i):
791+
class C:
792+
@staticmethod
793+
@Entrypoint
794+
def f():
795+
return i
796+
return C
797+
798+
799+
@TypeFunction
800+
def RefsAValueInStaticmethod(i):
801+
class C:
802+
@staticmethod
803+
def f():
804+
return i
805+
return C
806+
807+
808+
@TypeFunction
809+
def RefsAValueInStaticmethodOnClass(i):
810+
class C(Class):
811+
@staticmethod
812+
def f():
813+
return i
814+
return C
815+
816+
817+
def test_type_function_identity_referencing_int_in_function_only():
818+
assert RefsAValueInEntrypointedStaticmethod(1).f() == 1
819+
assert RefsAValueInEntrypointedStaticmethod(2).f() == 2
820+
821+
assert (
822+
identityHash(RefsAValueInStaticmethod(1))
823+
!= identityHash(RefsAValueInStaticmethod(2))
824+
)
825+
826+
assert (
827+
identityHash(RefsAValueInEntrypointedStaticmethod(1))
828+
!= identityHash(RefsAValueInEntrypointedStaticmethod(2))
829+
)
830+
831+
assert (
832+
identityHash(RefsAValueInStaticmethodOnClass(1))
833+
!= identityHash(RefsAValueInStaticmethodOnClass(2))
834+
)

0 commit comments

Comments
 (0)