Skip to content

Commit 8b734ed

Browse files
committed
Ensure we can add untyped tuples to List/TupleOf in the compiler.
1 parent 1f59e75 commit 8b734ed

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

typed_python/compiler/tests/list_of_compilation_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,15 @@ def assignIt(l, x):
787787

788788
with self.assertRaises(TypeError):
789789
aList[0] = "hi"
790+
791+
def test_add_untyped_tuple(self):
792+
@Entrypoint
793+
def addIt(aTup: ListOf(int), x: int):
794+
return aTup + (x,)
795+
796+
@Entrypoint
797+
def addItLst(aTup: ListOf(int), x: int):
798+
return aTup + [x]
799+
800+
assert addIt((1, 2), 3) == [1, 2, 3]
801+
assert addItLst((1, 2), 3) == [1, 2, 3]

typed_python/compiler/tests/tuple_of_compilation_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,28 @@ def sliceIt(aTup: TupleOf(int), x: int):
232232
return aTup[:x]
233233

234234
assert sliceIt((1, 2, 3), 1) == (1,)
235+
236+
def test_add_untyped_tuple(self):
237+
@Entrypoint
238+
def addIt(aTup: TupleOf(int), x: int):
239+
return aTup + (x,)
240+
241+
@Entrypoint
242+
def addItLst(aTup: TupleOf(int), x: int):
243+
return aTup + [x]
244+
245+
assert addIt((1, 2), 3) == (1, 2, 3)
246+
assert addItLst((1, 2), 3) == (1, 2, 3)
247+
248+
def test_add_untyped_tuple_reversed(self):
249+
@Entrypoint
250+
def addIt(aTup: TupleOf(int), x: int):
251+
return (x,) + aTup
252+
253+
@Entrypoint
254+
def addItLst(aTup: TupleOf(int), x: int):
255+
return [x] + aTup
256+
257+
assert addIt((1, 2), 3) == (3, 1, 2)
258+
assert addItLst((1, 2), 3) == (3, 1, 2)
259+

typed_python/compiler/type_wrappers/tuple_of_wrapper.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,24 @@ def concatenate_tuple_or_list(l, r):
218218
return result
219219

220220

221+
def concatenate_tuple_or_list_reversed(l, r):
222+
result = PreReservedTupleOrList(type(r))(len(l) + len(r))
223+
224+
ix = 0
225+
226+
for item in l:
227+
result._initializeItemUnsafe(ix, item)
228+
ix += 1
229+
result.setSizeUnsafe(ix)
230+
231+
for item in r:
232+
result._initializeItemUnsafe(ix, item)
233+
ix += 1
234+
result.setSizeUnsafe(ix)
235+
236+
return result
237+
238+
221239
def list_or_tupleof_of_slice(aList, start, stop, step):
222240
if step is None:
223241
return list_or_tupleof_of_slice(aList, start, stop, 1)
@@ -338,7 +356,7 @@ def generateNativeDestructorFunction(self, context, out, inst):
338356
)
339357

340358
def convert_bin_op(self, context, left, op, right, inplace):
341-
if issubclass(right.expr_type.typeRepresentation, (TupleOf, ListOf)):
359+
if issubclass(right.expr_type.typeRepresentation, (TupleOf, ListOf, Tuple, NamedTuple)):
342360
if op.matches.Add:
343361
return context.call_py_function(concatenate_tuple_or_list, (left, right), {})
344362

@@ -359,6 +377,10 @@ def convert_bin_op(self, context, left, op, right, inplace):
359377
return super().convert_bin_op(context, left, op, right, inplace)
360378

361379
def convert_bin_op_reverse(self, context, right, op, left, inplace):
380+
if issubclass(right.expr_type.typeRepresentation, (TupleOf, ListOf, Tuple, NamedTuple)):
381+
if op.matches.Add:
382+
return context.call_py_function(concatenate_tuple_or_list_reversed, (left, right), {})
383+
362384
if op.matches.In or op.matches.NotIn:
363385
left = left.convert_to_type(self.typeRepresentation.ElementType, ConversionLevel.Implicit)
364386

0 commit comments

Comments
 (0)