Skip to content

Commit efff26c

Browse files
committed
Ensure that Held arguments to entrypointed methods work correctly.
1 parent ac9e625 commit efff26c

File tree

5 files changed

+135
-37
lines changed

5 files changed

+135
-37
lines changed

typed_python/FunctionCallArgMapping.hpp

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,50 @@ class FunctionCallArgMapping {
3131
FunctionCallArgMapping& operator=(const FunctionCallArgMapping&) = delete;
3232

3333
public:
34+
class FunctionArg {
35+
public:
36+
FunctionArg() :
37+
mInstance(),
38+
mValid(false),
39+
mRawPtr(nullptr)
40+
{
41+
}
42+
43+
FunctionArg(instance_ptr inData) :
44+
mInstance(),
45+
mValid(true),
46+
mRawPtr(inData)
47+
{
48+
}
49+
50+
FunctionArg(Instance i) :
51+
mInstance(i),
52+
mValid(true),
53+
mRawPtr(nullptr)
54+
{
55+
}
56+
57+
bool isValid() const {
58+
return mValid;
59+
}
60+
61+
instance_ptr dataPtr() const {
62+
if (mRawPtr) {
63+
return mRawPtr;
64+
}
65+
66+
return mInstance.data();
67+
}
68+
69+
private:
70+
bool mValid;
71+
72+
Instance mInstance;
73+
74+
instance_ptr mRawPtr;
75+
};
76+
77+
3478
FunctionCallArgMapping(const Function::Overload& overload) :
3579
mArgs(overload.getArgs()),
3680
mCurrentArgumentIx(0),
@@ -331,36 +375,38 @@ class FunctionCallArgMapping {
331375
return res;
332376
}
333377

334-
std::pair<Instance, bool> extractArgWithType(int argIx, Type* argType) const {
378+
FunctionArg extractArgWithType(int argIx, Type* argType) const {
335379
if (mArgs[argIx].getIsNormalArg()) {
336380
try {
337381
Type* actualType = PyInstance::extractTypeFrom(mSingleValueArgs[argIx]->ob_type);
338382
if (actualType == argType) {
339383
// nothing to do!
340384
PyInstance* argAsPyInstance = ((PyInstance*)mSingleValueArgs[argIx]);
341385

342-
return std::make_pair(
343-
argAsPyInstance->mContainingInstance,
344-
true
386+
if (argAsPyInstance->mTemporaryRefTo) {
387+
return FunctionArg(argAsPyInstance->mTemporaryRefTo);
388+
}
389+
390+
return FunctionArg(
391+
argAsPyInstance->mContainingInstance
345392
);
346393
}
347394

348-
return std::make_pair(
395+
return FunctionArg(
349396
Instance::createAndInitialize(argType, [&](instance_ptr p) {
350397
PyInstance::copyConstructFromPythonInstance(
351398
argType, p, mSingleValueArgs[argIx], ConversionLevel::Signature
352399
);
353-
}),
354-
true
400+
})
355401
);
356402
} catch(PythonExceptionSet& s) {
357403
// failed to convert, but keep going
358404
PyErr_Clear();
359-
return std::pair<Instance, bool>(Instance(), false);
405+
return FunctionArg();
360406
}
361407
catch(...) {
362408
// not a valid conversion
363-
return std::pair<Instance, bool>(Instance(), false);
409+
return FunctionArg();
364410
}
365411
} else if (mArgs[argIx].getIsStarArg()) {
366412
if (argType->getTypeCategory() != Type::TypeCategory::catTuple) {
@@ -370,29 +416,28 @@ class FunctionCallArgMapping {
370416
Tuple* tup = (Tuple*)argType;
371417

372418
if (mStarArgValues.size() != tup->getTypes().size()) {
373-
return std::pair<Instance, bool>(Instance(), false);
419+
return FunctionArg();
374420
}
375421

376422
try {
377-
return std::make_pair(
423+
return FunctionArg(
378424
Instance::createAndInitialize(tup, [&](instance_ptr p) {
379425
tup->constructor(p, [&](instance_ptr subElt, int tupArg) {
380426
PyInstance::copyConstructFromPythonInstance(
381427
tup->getTypes()[tupArg], subElt, mStarArgValues[tupArg],
382428
ConversionLevel::Signature
383429
);
384430
});
385-
}),
386-
true
431+
})
387432
);
388433
} catch(PythonExceptionSet& s) {
389434
// failed to convert, but keep going
390435
PyErr_Clear();
391-
return std::pair<Instance, bool>(Instance(), false);
436+
return FunctionArg();
392437
}
393438
catch(...) {
394439
// not a valid conversion
395-
return std::pair<Instance, bool>(Instance(), false);
440+
return FunctionArg();
396441
}
397442
} else if (mArgs[argIx].getIsKwarg()) {
398443
if (argType->getTypeCategory() != Type::TypeCategory::catNamedTuple) {
@@ -402,17 +447,17 @@ class FunctionCallArgMapping {
402447
NamedTuple* tup = (NamedTuple*)argType;
403448

404449
if (mKwargValues.size() != tup->getTypes().size()) {
405-
return std::pair<Instance, bool>(Instance(), false);
450+
return FunctionArg();
406451
}
407452

408453
for (long k = 0; k < mKwargValues.size(); k++) {
409454
if (mKwargValues[k].first != tup->getNames()[k]) {
410-
return std::pair<Instance, bool>(Instance(), false);
455+
return FunctionArg();
411456
}
412457
}
413458

414459
try {
415-
return std::make_pair(
460+
return FunctionArg(
416461
Instance::createAndInitialize(tup, [&](instance_ptr p) {
417462
tup->constructor(p, [&](instance_ptr subElt, int tupArg) {
418463
PyInstance::copyConstructFromPythonInstance(
@@ -422,17 +467,16 @@ class FunctionCallArgMapping {
422467
ConversionLevel::Signature
423468
);
424469
});
425-
}),
426-
true
470+
})
427471
);
428472
} catch(PythonExceptionSet& s) {
429473
// failed to convert, but keep going
430474
PyErr_Clear();
431-
return std::pair<Instance, bool>(Instance(), false);
475+
return FunctionArg();
432476
}
433477
catch(...) {
434478
// not a valid conversion
435-
return std::pair<Instance, bool>(Instance(), false);
479+
return FunctionArg();
436480
}
437481
}
438482

typed_python/PyFunctionInstance.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ std::pair<bool, PyObject*> PyFunctionInstance::dispatchFunctionCallToCompiledSpe
514514
throw std::runtime_error("Malformed function specialization: missing a return type.");
515515
}
516516

517-
std::vector<Instance> instances;
517+
std::vector<FunctionCallArgMapping::FunctionArg> mappingArgs;
518518

519519
// first, see if we can short-circuit
520520
for (long k = 0; k < overload.getArgs().size(); k++) {
@@ -532,10 +532,10 @@ std::pair<bool, PyObject*> PyFunctionInstance::dispatchFunctionCallToCompiledSpe
532532
auto arg = overload.getArgs()[k];
533533
Type* argType = specialization.getArgTypes()[k];
534534

535-
std::pair<Instance, bool> res = mapper.extractArgWithType(k, argType);
535+
FunctionCallArgMapping::FunctionArg res = mapper.extractArgWithType(k, argType);
536536

537-
if (res.second) {
538-
instances.push_back(res.first);
537+
if (res.isValid()) {
538+
mappingArgs.push_back(res);
539539
} else {
540540
return std::pair<bool, PyObject*>(false, (PyObject*)nullptr);
541541
}
@@ -552,8 +552,8 @@ std::pair<bool, PyObject*> PyFunctionInstance::dispatchFunctionCallToCompiledSpe
552552
args.push_back(closureCells.back().data());
553553
}
554554

555-
for (auto& i: instances) {
556-
args.push_back(i.data());
555+
for (auto& i: mappingArgs) {
556+
args.push_back(i.dataPtr());
557557
}
558558

559559
auto functionPtr = specialization.getFuncPtr();

typed_python/compiler/tests/class_compilation_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,19 @@ def callStr(x):
366366
self.assertEqual(callRepr(ClassWithReprAndStr()), "repr")
367367
self.assertEqual(callStr(ClassWithReprAndStr()), "str")
368368

369+
def test_class_str_without_override(self):
370+
class NoStr(Class):
371+
x = Member(int)
372+
y = Member(float)
373+
374+
@Entrypoint
375+
def callStr(x):
376+
return str(x)
377+
378+
n = NoStr()
379+
380+
assert str(n) == callStr(n)
381+
369382
def test_compiled_class_subclass_layout(self):
370383
class BaseClass(Class):
371384
x = Member(int)

typed_python/compiler/tests/held_class_compilation_test.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import unittest
22
import time
33
from typed_python import _types
4-
from typed_python import Class, Final, ListOf, Held, Member, Entrypoint, Forward
4+
from typed_python import Class, Final, ListOf, Held, Member, Entrypoint, Forward, pointerTo
55

66

77
@Held
88
class H(Class, Final):
9-
x = Member(int)
10-
y = Member(float)
9+
x = Member(int, nonempty=True)
10+
y = Member(float, nonempty=True)
1111

1212
def f(self):
1313
return self.x + self.y
@@ -21,6 +21,14 @@ def entrypointedIncrement(self):
2121
self.x += 1
2222
self.y += 1
2323

24+
@Entrypoint
25+
def getX(self):
26+
return self.x
27+
28+
@Entrypoint
29+
def pointerToSelf(self):
30+
return pointerTo(self)
31+
2432

2533
Complex = Forward("Complex")
2634

@@ -55,14 +63,43 @@ def checkCompiler(self, f, *args, **kwargs):
5563

5664
self.assertEqual(compiledOutput, interpretedOutput)
5765

66+
def test_held_class_pointer_to_self(self):
67+
@Entrypoint
68+
def callPointerTo(h):
69+
return pointerTo(h)
70+
71+
h = H(x=2, y=3)
72+
73+
assert callPointerTo(h) == pointerTo(h)
74+
assert h.pointerToSelf() == pointerTo(h)
75+
5876
def test_held_class_entrypointed_methods(self):
59-
h1 = H()
60-
h2 = H()
77+
h1 = H(x=2, y=3)
78+
h2 = H(x=2, y=3)
6179

6280
h1.entrypointedIncrement()
6381
h2.increment()
6482

6583
assert h1.x == h2.x
84+
assert h1.getX() == h1.x
85+
86+
def test_stringify_held_class(self):
87+
h = H(x=2, y=3)
88+
89+
@Entrypoint
90+
def callStr(h):
91+
return str(h)
92+
93+
assert str(h) == callStr(h)
94+
95+
def test_pointer_to_held_class_compiles(self):
96+
h = H(x=2, y=3)
97+
98+
@Entrypoint
99+
def getPtr(h):
100+
return pointerTo(h)
101+
102+
assert pointerTo(h) == getPtr(h)
66103

67104
def test_pass_held_to_function_with_signature(self):
68105
@Entrypoint

typed_python/compiler/type_wrappers/held_class_wrapper.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ClassOrAlternativeWrapperMixin
2121
)
2222

23-
from typed_python import _types, RefTo
23+
from typed_python import _types, RefTo, PointerTo
2424

2525
import typed_python.compiler.native_ast as native_ast
2626
import typed_python.compiler
@@ -40,6 +40,7 @@ def __init__(self, t):
4040
self.heldClassType = t
4141
self.classType = t.Class
4242
self.refToType = RefTo(t)
43+
self.ptrToType = PointerTo(t)
4344

4445
self.classTypeWrapper = typeWrapper(t.Class)
4546
self.nameToIndex = self.classTypeWrapper.nameToIndex
@@ -58,6 +59,10 @@ def __init__(self, t):
5859
count=_types.bytecount(self.heldClassType)
5960
)
6061

62+
def convert_pointerTo(self, context, instance):
63+
assert instance.isReference
64+
return instance.changeType(self.ptrToType, isReferenceOverride=False)
65+
6166
def fieldGuaranteedInitialized(self, ix):
6267
if self.classType.ClassMembers[
6368
self.classType.MemberNames[ix]
@@ -87,13 +92,12 @@ def convert_attribute_pointerTo(self, context, pointerInstance, attribute):
8792
return super().convert_attribute(context, pointerInstance, attribute)
8893

8994
def _can_convert_to_type(self, otherType, explicit):
90-
return False
95+
return super()._can_convert_to_type(otherType, explicit)
9196

9297
def _can_convert_from_type(self, otherType, explicit):
93-
return False
98+
return super()._can_convert_from_type(otherType, explicit)
9499

95100
def convert_to_type_with_target(self, context, instance, targetVal, conversionLevel, mayThrowOnFailure=False):
96-
print("Converting held class ", self, " to ", targetVal)
97101
return super().convert_to_type_with_target(context, instance, targetVal, conversionLevel, mayThrowOnFailure)
98102

99103
def bytesOfInitBitsForInstance(self, instance):

0 commit comments

Comments
 (0)