Skip to content

Commit 1f97560

Browse files
committed
ConstDict __add__ and __sub__ are compilable.
1 parent 224a0fb commit 1f97560

File tree

4 files changed

+230
-12
lines changed

4 files changed

+230
-12
lines changed

typed_python/compiler/tests/alternative_compilation_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typed_python import (
1616
TypeFunction, Int16, UInt64, Float32, Alternative, Forward,
17-
Dict, ConstDict, ListOf, Compiled, OneOf
17+
Dict, ListOf, Compiled, OneOf
1818
)
1919
import typed_python._types as _types
2020
from typed_python import Entrypoint
@@ -399,7 +399,7 @@ def test_compile_alternative_reverse_methods(self):
399399
)
400400

401401
values = [1, Int16(1), UInt64(1), 1.234, Float32(1.234), True, "abc",
402-
ListOf(int)((1, 2)), ConstDict(str, str)({"a": "1"})]
402+
ListOf(int)((1, 2))]
403403
for v in values:
404404
T = type(v)
405405

@@ -1059,7 +1059,7 @@ def test_compile_simple_alternative_reverse_methods(self):
10591059
)
10601060

10611061
values = [1, Int16(1), UInt64(1), 1.234, Float32(1.234), True, "abc",
1062-
ListOf(int)((1, 2)), ConstDict(str, str)({"a": "1"})]
1062+
ListOf(int)((1, 2))]
10631063
for v in values:
10641064
T = type(v)
10651065

typed_python/compiler/tests/class_compilation_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typed_python import (
2121
Class,
2222
Dict,
23-
ConstDict,
2423
TupleOf,
2524
ListOf,
2625
Member,
@@ -1064,7 +1063,6 @@ class C(Class, Final):
10641063
True,
10651064
"abc",
10661065
ListOf(int)((1, 2)),
1067-
ConstDict(str, str)({"a": "1"}),
10681066
]
10691067
for v in values:
10701068
T = type(v)

typed_python/compiler/tests/const_dict_compilation_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,52 @@ def f():
413413
return ConstDict(int, str)(res)
414414

415415
assert f() == Entrypoint(f)()
416+
417+
def test_const_dict_add(self):
418+
@Entrypoint
419+
def addDicts(d1, d2):
420+
return d1 + d2
421+
422+
T = ConstDict(int, int)
423+
424+
someDicts = [
425+
{1: 2, 3: 4},
426+
{1: 4},
427+
{1: 2, 2: 4, 3: 8},
428+
{2: 3, 5: 1, 4: 9, 3: 6},
429+
{},
430+
{7: 8},
431+
{7: 8, 2: 3},
432+
{7: 8, 2: 4},
433+
]
434+
435+
someDicts = [T(x) for x in someDicts]
436+
437+
for d1 in someDicts:
438+
for d2 in someDicts:
439+
assert d1 + d2 == addDicts(d1, d2)
440+
441+
def test_const_dict_sub(self):
442+
@Entrypoint
443+
def subKey(dct, key):
444+
return dct - key
445+
446+
T = ConstDict(int, int)
447+
448+
someDicts = [
449+
{1: 2, 3: 4},
450+
{1: 2, 3: 4, 5: 6},
451+
{1: 4},
452+
{1: 2, 2: 4, 3: 8},
453+
{2: 3, 5: 1, 4: 9, 3: 6},
454+
{},
455+
{7: 8},
456+
{7: 8, 2: 3},
457+
{7: 8, 2: 4},
458+
]
459+
460+
someDicts = [T(x) for x in someDicts]
461+
462+
for d1 in someDicts:
463+
for d2 in someDicts:
464+
assert d1 - TupleOf(int)(d2) == subKey(d1, TupleOf(int)(d2))

typed_python/compiler/type_wrappers/const_dict_wrapper.py

Lines changed: 178 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typed_python.compiler.type_wrappers.dict_wrapper import DictWrapper
2222
from typed_python import (
2323
Tuple, TypeFunction, Held, Member, Final, Class, ConstDict, PointerTo, Int32, UInt8,
24-
bytecount
24+
bytecount, TupleOf, Set
2525
)
2626
from typed_python.compiler.type_wrappers.compilable_builtin import CompilableBuiltin
2727

@@ -133,7 +133,7 @@ def const_dict_get(constDict, key, default):
133133
return default
134134

135135

136-
def const_dict_contains(constDict, key):
136+
def const_dict_index_of_key(constDict, key):
137137
# perform a binary search
138138
lowIx = 0
139139
highIx = len(constDict)
@@ -148,9 +148,13 @@ def const_dict_contains(constDict, key):
148148
elif key < keyAtVal:
149149
highIx = mid
150150
else:
151-
return True
151+
return mid
152+
153+
return -1
152154

153-
return False
155+
156+
def const_dict_contains(constDict, key):
157+
return const_dict_index_of_key(constDict, key) >= 0
154158

155159

156160
class Malloc(CompilableBuiltin):
@@ -176,7 +180,7 @@ def convert_call(self, context, instance, args, kwargs):
176180
return super().convert_call(context, instance, args, kwargs)
177181

178182

179-
def initialize_empty_const_dict(ptrToOutDict, kvCount):
183+
def allocate_empty_const_dict(ptrToOutDict, kvCount):
180184
byteCount = kvCount * (
181185
bytecount(ptrToOutDict.ElementType.KeyType)
182186
+ bytecount(ptrToOutDict.ElementType.ValueType)
@@ -193,7 +197,137 @@ def initialize_empty_const_dict(ptrToOutDict, kvCount):
193197
# subpointers
194198
(target.cast(UInt8) + 16).cast(Int32).set(0)
195199

196-
ptrToOutDict.cast(PointerTo(None)).set(target)
200+
return target
201+
202+
203+
def const_dict_sub(ptrToOutDict, lhs, rhs):
204+
# lhs is a dict and rhs is a TupleOf
205+
indicesToRemove = Set(int)()
206+
207+
for toRemove in rhs:
208+
ix = const_dict_index_of_key(lhs, toRemove)
209+
if ix >= 0:
210+
indicesToRemove.add(ix)
211+
212+
if len(indicesToRemove) == len(lhs):
213+
# initialize it to zero
214+
ptrToOutDict.cast(PointerTo(None)).set(PointerTo(None)())
215+
return
216+
217+
ptrToOutDict.cast(PointerTo(None)).set(
218+
allocate_empty_const_dict(ptrToOutDict, len(lhs) - len(indicesToRemove))
219+
)
220+
221+
outIx = 0
222+
223+
for i in range(len(lhs)):
224+
if i not in indicesToRemove:
225+
ptrToOutDict.get().initialize_kv_pair_unsafe(
226+
outIx,
227+
lhs.get_key_by_index_unsafe(i),
228+
lhs.get_value_by_index_unsafe(i)
229+
)
230+
outIx += 1
231+
232+
assert outIx == len(lhs) - len(indicesToRemove)
233+
ptrToOutDict.get().set_kv_count_unsafe(outIx)
234+
235+
236+
def const_dict_add(ptrToOutDict, lhs, rhs):
237+
if not lhs:
238+
ptrToOutDict.initialize(rhs)
239+
return
240+
241+
if not rhs:
242+
ptrToOutDict.initialize(lhs)
243+
return
244+
245+
lhsIx = 0
246+
rhsIx = 0
247+
248+
lenLhs = len(lhs)
249+
lenRhs = len(rhs)
250+
251+
lhsValuesToKeep = 0
252+
253+
while lhsIx < lenLhs:
254+
if rhsIx < lenRhs:
255+
lhsKey = lhs.get_key_by_index_unsafe(lhsIx)
256+
rhsKey = rhs.get_key_by_index_unsafe(rhsIx)
257+
258+
if lhsKey < rhsKey:
259+
lhsIx += 1
260+
lhsValuesToKeep += 1
261+
elif rhsKey < lhsKey:
262+
rhsIx += 1
263+
else:
264+
lhsIx += 1
265+
rhsIx += 1
266+
else:
267+
lhsValuesToKeep += 1
268+
lhsIx += 1
269+
270+
try:
271+
ptrToOutDict.cast(PointerTo(None)).set(
272+
allocate_empty_const_dict(ptrToOutDict, len(rhs) + lhsValuesToKeep)
273+
)
274+
275+
outIx = 0
276+
lhsIx = 0
277+
rhsIx = 0
278+
279+
while lhsIx < lenLhs or rhsIx < lenRhs:
280+
if rhsIx < lenRhs and lhsIx < lenLhs:
281+
lhsKey = lhs.get_key_by_index_unsafe(lhsIx)
282+
rhsKey = rhs.get_key_by_index_unsafe(rhsIx)
283+
284+
if lhsKey < rhsKey:
285+
ptrToOutDict.get().initialize_kv_pair_unsafe(
286+
outIx,
287+
lhsKey,
288+
lhs.get_value_by_index_unsafe(lhsIx)
289+
)
290+
lhsIx += 1
291+
outIx += 1
292+
else:
293+
ptrToOutDict.get().initialize_kv_pair_unsafe(
294+
outIx,
295+
rhsKey,
296+
rhs.get_value_by_index_unsafe(rhsIx)
297+
)
298+
rhsIx += 1
299+
outIx += 1
300+
301+
if lhsKey == rhsKey:
302+
lhsIx += 1
303+
elif lhsIx < lenLhs:
304+
lhsKey = lhs.get_key_by_index_unsafe(lhsIx)
305+
306+
ptrToOutDict.get().initialize_kv_pair_unsafe(
307+
outIx,
308+
lhsKey,
309+
lhs.get_value_by_index_unsafe(lhsIx)
310+
)
311+
lhsIx += 1
312+
outIx += 1
313+
else:
314+
assert rhsIx < lenRhs
315+
316+
rhsKey = rhs.get_key_by_index_unsafe(rhsIx)
317+
ptrToOutDict.get().initialize_kv_pair_unsafe(
318+
outIx,
319+
rhsKey,
320+
rhs.get_value_by_index_unsafe(rhsIx)
321+
)
322+
rhsIx += 1
323+
outIx += 1
324+
325+
assert outIx == len(rhs) + lhsValuesToKeep
326+
ptrToOutDict.get().set_kv_count_unsafe(outIx)
327+
328+
except: # noqa
329+
ptrToOutDict.destroy()
330+
raise
197331

198332

199333
def initialize_const_dict_from_mappable(ptrToOutDict, mappable, mayThrow):
@@ -210,7 +344,9 @@ def initialize_const_dict_from_mappable(ptrToOutDict, mappable, mayThrow):
210344
we failed, and ptrToOutDict will not point to an initialized
211345
dictionary.
212346
"""
213-
initialize_empty_const_dict(ptrToOutDict, len(mappable))
347+
ptrToOutDict.cast(PointerTo(None)).set(
348+
allocate_empty_const_dict(ptrToOutDict, len(mappable))
349+
)
214350

215351
try:
216352
count = 0
@@ -516,6 +652,41 @@ def convert_bin_op(self, context, left, op, right, inplace):
516652
if op.matches.GtE:
517653
return context.call_py_function(const_dict_gte, (left, right), {})
518654

655+
if op.matches.Add:
656+
newDict = context.allocateUninitializedSlot(self.typeRepresentation)
657+
658+
res = context.call_py_function(
659+
const_dict_add, (newDict.asPointer(), left, right), {}
660+
)
661+
662+
if res is None:
663+
return res
664+
665+
context.markUninitializedSlotInitialized(newDict)
666+
667+
return newDict
668+
669+
if op.matches.Sub:
670+
right = right.convert_to_type(
671+
TupleOf(self.typeRepresentation.KeyType), ConversionLevel.Implicit
672+
)
673+
674+
if right is None:
675+
return None
676+
677+
newDict = context.allocateUninitializedSlot(self.typeRepresentation)
678+
679+
res = context.call_py_function(
680+
const_dict_sub, (newDict.asPointer(), left, right), {}
681+
)
682+
683+
if res is None:
684+
return res
685+
686+
context.markUninitializedSlotInitialized(newDict)
687+
688+
return newDict
689+
519690
return super().convert_bin_op(context, left, op, right, inplace)
520691

521692
def convert_bin_op_reverse(self, context, left, op, right, inplace):

0 commit comments

Comments
 (0)