Skip to content

Commit 224a0fb

Browse files
committed
Ensure that if 'A' is an alternative with X as a choice, that Set(A.X) works correctly.
1 parent c5f033f commit 224a0fb

File tree

9 files changed

+81
-54
lines changed

9 files changed

+81
-54
lines changed

typed_python/AlternativeType.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -123,35 +123,6 @@ bool Alternative::cmp(instance_ptr left, instance_ptr right, int pyComparisonOp,
123123
return m_subtypes[record_l.which].second->cmp(record_l.data, record_r.data, pyComparisonOp, suppressExceptions);
124124
}
125125

126-
//static
127-
bool Alternative::cmpStatic(Alternative* altT, instance_ptr left, instance_ptr right, int64_t pyComparisonOp) {
128-
if (altT->m_all_alternatives_empty) {
129-
if (*(uint8_t*)left < *(uint8_t*)right) {
130-
return cmpResultToBoolForPyOrdering(pyComparisonOp, -1);
131-
}
132-
if (*(uint8_t*)left > *(uint8_t*)right) {
133-
return cmpResultToBoolForPyOrdering(pyComparisonOp, 1);
134-
}
135-
return cmpResultToBoolForPyOrdering(pyComparisonOp, 0);
136-
}
137-
138-
layout& record_l = **(layout**)left;
139-
layout& record_r = **(layout**)right;
140-
141-
if ( &record_l == &record_r ) {
142-
return cmpResultToBoolForPyOrdering(pyComparisonOp, 0);
143-
}
144-
145-
if (record_l.which < record_r.which) {
146-
return cmpResultToBoolForPyOrdering(pyComparisonOp, -1);
147-
}
148-
if (record_l.which > record_r.which) {
149-
return cmpResultToBoolForPyOrdering(pyComparisonOp, 1);
150-
}
151-
152-
return altT->m_subtypes[record_l.which].second->cmp(record_l.data, record_r.data, pyComparisonOp, false);
153-
}
154-
155126
void Alternative::repr(instance_ptr self, ReprAccumulator& stream, bool isStr) {
156127
PushReprState isNew(stream, self);
157128

typed_python/AlternativeType.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,7 @@ class Alternative : public Type {
182182
}
183183
}
184184

185-
186185
bool cmp(instance_ptr left, instance_ptr right, int pyComparisonOp, bool suppressExceptions);
187-
static bool cmpStatic(Alternative* altT, instance_ptr left, instance_ptr right, int64_t pyComparisonOp);
188186

189187
template<class buf_t>
190188
void serialize(instance_ptr self, buf_t& buffer, size_t fieldNumber) {

typed_python/ConcreteAlternativeType.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class ConcreteAlternative : public Type {
103103
}
104104

105105
bool cmp(instance_ptr left, instance_ptr right, int pyComparisonOp, bool suppressExceptions) {
106-
return m_alternative->cmp(left,right, pyComparisonOp, suppressExceptions);
106+
return m_alternative->cmp(left, right, pyComparisonOp, suppressExceptions);
107107
}
108108

109109
void constructor(instance_ptr self);

typed_python/_runtime.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,13 +466,15 @@ extern "C" {
466466
return StringType::cmpStatic(lhs, rhs);
467467
}
468468

469-
bool np_runtime_alternative_cmp(Alternative* tp, instance_ptr lhs, instance_ptr rhs, int64_t pyComparisonOp) {
470-
return Alternative::cmpStatic(tp, lhs, rhs, pyComparisonOp);
471-
}
469+
bool np_instance_cmp(Type* t, instance_ptr lhs, instance_ptr rhs, int64_t pyComparisonOp) {
470+
try {
471+
return t->cmp(lhs, rhs, pyComparisonOp, false);
472+
} catch(std::exception& e) {
473+
PyEnsureGilAcquired getTheGil;
472474

473-
bool np_runtime_class_cmp(Class* tp, instance_ptr lhs, instance_ptr rhs, int64_t pyComparisonOp) {
474-
PyEnsureGilAcquired acquireTheGil;
475-
return Class::cmpStatic(tp, lhs, rhs, pyComparisonOp);
475+
PyErr_SetString(PyExc_TypeError, e.what());
476+
throw PythonExceptionSet();
477+
}
476478
}
477479

478480
StringType::layout* np_typePtrToName(Type* t) {

typed_python/compiler/tests/alternative_compilation_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import psutil
2525
from math import trunc, floor, ceil
2626
from typed_python.compiler.type_wrappers.class_wrapper import classCouldBeInstanceOf
27+
from typed_python.compiler.type_wrappers.hash_table_implementation import NativeHash
2728

2829

2930
class TestAlternativeCompilation(unittest.TestCase):
@@ -1724,3 +1725,32 @@ def g(x: A):
17241725
return f(x)
17251726

17261727
assert g(A.X(x=10)) == 10
1728+
1729+
def test_alternative_hashing(self):
1730+
A = Alternative("A", A=dict(a=int))
1731+
1732+
assert NativeHash.callHash(A.A, A.A(a=12)) == hash(A.A(a=12))
1733+
assert NativeHash.callHash(A, A.A(a=12)) == hash(A.A(a=12))
1734+
1735+
def test_alternative_equality(self):
1736+
AB = Alternative("AB", A=dict(a=int), B=dict(b=str))
1737+
1738+
@Entrypoint
1739+
def eqConcrete(a1: AB.A, a2: AB.A):
1740+
return a1 == a2
1741+
1742+
@Entrypoint
1743+
def eqMixed(a1: AB, a2: AB.A):
1744+
return a1 == a2
1745+
1746+
@Entrypoint
1747+
def eq(a1: AB, a2: AB):
1748+
return a1 == a2
1749+
1750+
assert not eq(AB.A(a=1), AB.A(a=2))
1751+
assert not eqMixed(AB.A(a=1), AB.A(a=2))
1752+
assert not eqConcrete(AB.A(a=1), AB.A(a=2))
1753+
1754+
assert eq(AB.A(a=1), AB.A(a=1))
1755+
assert eqMixed(AB.A(a=1), AB.A(a=1))
1756+
assert eqConcrete(AB.A(a=1), AB.A(a=1))

typed_python/compiler/tests/set_compilation_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typed_python import (
1616
Set, ListOf, Entrypoint, Compiled, Tuple, TupleOf, NamedTuple, Dict, ConstDict, OneOf,
17-
UInt8
17+
UInt8, Alternative
1818
)
1919
from typed_python.compiler.type_wrappers.set_wrapper import (
2020
set_union, set_intersection, set_difference,
@@ -1191,3 +1191,28 @@ def callC(y: OneOf(None, Set(float))):
11911191
return C(y)
11921192

11931193
callC([1])
1194+
1195+
def test_set_of_specific_alternative(self):
1196+
A = Alternative(
1197+
"A",
1198+
A=dict(a=int),
1199+
B=dict(b=str),
1200+
)
1201+
1202+
s = Set(A.A)()
1203+
1204+
@Entrypoint
1205+
def add(s, k):
1206+
s.add(k)
1207+
1208+
@Entrypoint
1209+
def contains(s, k):
1210+
return k in s
1211+
1212+
someInts = [1, 10, 2, 6, 123, 13]
1213+
1214+
for i in someInts:
1215+
add(s, A.A(a=i))
1216+
assert A.A(a=i) in s
1217+
assert contains(s, A.A(a=i))
1218+
assert not contains(s, A.A(a=-i))

typed_python/compiler/type_wrappers/alternative_wrapper.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,13 @@ def convert_comparison(self, context, lhs, op, rhs):
6161
if py_code < 0:
6262
return super().convert_comparison(context, lhs, op, rhs)
6363

64-
if not lhs.isReference:
65-
lhs = context.pushMove(lhs)
64+
if _types.isBinaryCompatible(lhs.expr_type.typeRepresentation, rhs.expr_type.typeRepresentation):
65+
lhs = lhs.ensureIsReference()
66+
rhs = rhs.ensureIsReference()
6667

67-
if not rhs.isReference:
68-
rhs = context.pushMove(rhs)
69-
70-
if lhs.expr_type.typeRepresentation == rhs.expr_type.typeRepresentation:
7168
return context.pushPod(
7269
bool,
73-
runtime_functions.alternative_cmp.call(
70+
runtime_functions.instance_cmp.call(
7471
context.getTypePointer(lhs.expr_type.typeRepresentation),
7572
lhs.expr.cast(VoidPtr),
7673
rhs.expr.cast(VoidPtr),

typed_python/compiler/type_wrappers/hash_table_implementation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def convert_call(self, context, instance, args, kwargs):
7777

7878
return super().convert_call(context, instance, args, kwargs)
7979

80+
@staticmethod
81+
def callHash(T, x):
82+
from typed_python import Entrypoint
83+
84+
@Entrypoint
85+
def hashIt(x):
86+
return NativeHash(T)(x)
87+
88+
return hashIt(x)
89+
8090

8191
def table_add_slot(instance, itemHash, slot):
8292
if (instance._hash_table_count * 2 + 1 > instance._hash_table_size or

typed_python/compiler/type_wrappers/runtime_functions.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,12 @@ def unaryPyobjCallTarget(name, retType=Void.pointer()):
526526
Void.pointer(), Void.pointer()
527527
)
528528

529-
alternative_cmp = externalCallTarget(
530-
"np_runtime_alternative_cmp",
529+
instance_cmp = externalCallTarget(
530+
"np_instance_cmp",
531531
Bool,
532532
Void.pointer(), Void.pointer(), Void.Pointer(), Int64
533533
)
534534

535-
class_cmp = externalCallTarget(
536-
"np_runtime_class_cmp",
537-
Bool,
538-
UInt64, Void.pointer(), Void.Pointer(), Int64
539-
)
540-
541535
string_chr_int64 = externalCallTarget(
542536
"nativepython_runtime_string_chr",
543537
Void.pointer(),

0 commit comments

Comments
 (0)