Skip to content

Commit ac9e625

Browse files
committed
Ensure that we don't arbitrarily duplicate HeldClass instances when calling functions.
In general, we want to ensure that if we pass a held class into a function, and that held class has a clear storage owner in the parent context, that we pass a reference to it into the function, so that we can modify the instance. For that to work, we need to not be making a bunch of copies of it.
1 parent f64696a commit ac9e625

File tree

4 files changed

+66
-0
lines changed

4 files changed

+66
-0
lines changed

typed_python/FunctionCallArgMapping.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ class FunctionCallArgMapping {
4747

4848
void coerceToType(py_obj_ptr& ptr, Type* target, ConversionLevel level) {
4949
try {
50+
Type* actualType = PyInstance::extractTypeFrom(ptr->ob_type);
51+
if (actualType == target) {
52+
// nothing to do!
53+
return;
54+
}
55+
5056
PyObject* coerced = PyInstance::initializePythonRepresentation(
5157
target,
5258
[&](instance_ptr data) {
@@ -328,6 +334,17 @@ class FunctionCallArgMapping {
328334
std::pair<Instance, bool> extractArgWithType(int argIx, Type* argType) const {
329335
if (mArgs[argIx].getIsNormalArg()) {
330336
try {
337+
Type* actualType = PyInstance::extractTypeFrom(mSingleValueArgs[argIx]->ob_type);
338+
if (actualType == argType) {
339+
// nothing to do!
340+
PyInstance* argAsPyInstance = ((PyInstance*)mSingleValueArgs[argIx]);
341+
342+
return std::make_pair(
343+
argAsPyInstance->mContainingInstance,
344+
true
345+
);
346+
}
347+
331348
return std::make_pair(
332349
Instance::createAndInitialize(argType, [&](instance_ptr p) {
333350
PyInstance::copyConstructFromPythonInstance(

typed_python/compiler/tests/held_class_compilation_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ def increment(self):
1616
self.x += 1
1717
self.y += 1
1818

19+
@Entrypoint
20+
def entrypointedIncrement(self):
21+
self.x += 1
22+
self.y += 1
23+
1924

2025
Complex = Forward("Complex")
2126

@@ -50,6 +55,37 @@ def checkCompiler(self, f, *args, **kwargs):
5055

5156
self.assertEqual(compiledOutput, interpretedOutput)
5257

58+
def test_held_class_entrypointed_methods(self):
59+
h1 = H()
60+
h2 = H()
61+
62+
h1.entrypointedIncrement()
63+
h2.increment()
64+
65+
assert h1.x == h2.x
66+
67+
def test_pass_held_to_function_with_signature(self):
68+
@Entrypoint
69+
def f(h: H):
70+
h.x = 100
71+
72+
@Entrypoint
73+
def g():
74+
h = H()
75+
f(h)
76+
return h
77+
78+
assert g().x == 100
79+
80+
def test_pass_held_by_ref_across_entrypoint(self):
81+
@Entrypoint
82+
def g(h):
83+
h.x = 100
84+
85+
h = H()
86+
g(h)
87+
assert h.x == 100
88+
5389
def test_compile_held_class(self):
5490
@Held
5591
class H(Class, Final):

typed_python/compiler/tests/held_class_interpreter_semantics_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,15 @@ def g(l):
166166

167167
runTest()
168168

169+
def testPassHeldClassToTypedArg(self):
170+
@Function
171+
def f(h: H):
172+
h.x = 100
173+
174+
h = H()
175+
f(h)
176+
assert h.x == 100
177+
169178
def testCallBoundMethodOnLeakedTemporaryCrashes(self):
170179
with self.assertRaisesRegex(Exception, "would have crashed"):
171180
def runTest():

typed_python/compiler/type_wrappers/held_class_wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def _can_convert_to_type(self, otherType, explicit):
9292
def _can_convert_from_type(self, otherType, explicit):
9393
return False
9494

95+
def convert_to_type_with_target(self, context, instance, targetVal, conversionLevel, mayThrowOnFailure=False):
96+
print("Converting held class ", self, " to ", targetVal)
97+
return super().convert_to_type_with_target(context, instance, targetVal, conversionLevel, mayThrowOnFailure)
98+
9599
def bytesOfInitBitsForInstance(self, instance):
96100
return native_ast.const_uint64_expr(self.bytesOfInitBits)
97101

0 commit comments

Comments
 (0)