2121from typed_python .compiler .type_wrappers .dict_wrapper import DictWrapper
2222from typed_python import (
2323 Tuple , TypeFunction , Held , Member , Final , Class , ConstDict , PointerTo , Int32 , UInt8 ,
24- bytecount
24+ bytecount , TupleOf , Set
2525)
2626from typed_python .compiler .type_wrappers .compilable_builtin import CompilableBuiltin
2727
@@ -133,7 +133,7 @@ def const_dict_get(constDict, key, default):
133133 return default
134134
135135
136- def const_dict_contains (constDict , key ):
136+ def const_dict_index_of_key (constDict , key ):
137137 # perform a binary search
138138 lowIx = 0
139139 highIx = len (constDict )
@@ -148,9 +148,13 @@ def const_dict_contains(constDict, key):
148148 elif key < keyAtVal :
149149 highIx = mid
150150 else :
151- return True
151+ return mid
152+
153+ return - 1
152154
153- return False
155+
156+ def const_dict_contains (constDict , key ):
157+ return const_dict_index_of_key (constDict , key ) >= 0
154158
155159
156160class Malloc (CompilableBuiltin ):
@@ -176,7 +180,7 @@ def convert_call(self, context, instance, args, kwargs):
176180 return super ().convert_call (context , instance , args , kwargs )
177181
178182
179- def initialize_empty_const_dict (ptrToOutDict , kvCount ):
183+ def allocate_empty_const_dict (ptrToOutDict , kvCount ):
180184 byteCount = kvCount * (
181185 bytecount (ptrToOutDict .ElementType .KeyType )
182186 + bytecount (ptrToOutDict .ElementType .ValueType )
@@ -193,7 +197,137 @@ def initialize_empty_const_dict(ptrToOutDict, kvCount):
193197 # subpointers
194198 (target .cast (UInt8 ) + 16 ).cast (Int32 ).set (0 )
195199
196- ptrToOutDict .cast (PointerTo (None )).set (target )
200+ return target
201+
202+
203+ def const_dict_sub (ptrToOutDict , lhs , rhs ):
204+ # lhs is a dict and rhs is a TupleOf
205+ indicesToRemove = Set (int )()
206+
207+ for toRemove in rhs :
208+ ix = const_dict_index_of_key (lhs , toRemove )
209+ if ix >= 0 :
210+ indicesToRemove .add (ix )
211+
212+ if len (indicesToRemove ) == len (lhs ):
213+ # initialize it to zero
214+ ptrToOutDict .cast (PointerTo (None )).set (PointerTo (None )())
215+ return
216+
217+ ptrToOutDict .cast (PointerTo (None )).set (
218+ allocate_empty_const_dict (ptrToOutDict , len (lhs ) - len (indicesToRemove ))
219+ )
220+
221+ outIx = 0
222+
223+ for i in range (len (lhs )):
224+ if i not in indicesToRemove :
225+ ptrToOutDict .get ().initialize_kv_pair_unsafe (
226+ outIx ,
227+ lhs .get_key_by_index_unsafe (i ),
228+ lhs .get_value_by_index_unsafe (i )
229+ )
230+ outIx += 1
231+
232+ assert outIx == len (lhs ) - len (indicesToRemove )
233+ ptrToOutDict .get ().set_kv_count_unsafe (outIx )
234+
235+
236+ def const_dict_add (ptrToOutDict , lhs , rhs ):
237+ if not lhs :
238+ ptrToOutDict .initialize (rhs )
239+ return
240+
241+ if not rhs :
242+ ptrToOutDict .initialize (lhs )
243+ return
244+
245+ lhsIx = 0
246+ rhsIx = 0
247+
248+ lenLhs = len (lhs )
249+ lenRhs = len (rhs )
250+
251+ lhsValuesToKeep = 0
252+
253+ while lhsIx < lenLhs :
254+ if rhsIx < lenRhs :
255+ lhsKey = lhs .get_key_by_index_unsafe (lhsIx )
256+ rhsKey = rhs .get_key_by_index_unsafe (rhsIx )
257+
258+ if lhsKey < rhsKey :
259+ lhsIx += 1
260+ lhsValuesToKeep += 1
261+ elif rhsKey < lhsKey :
262+ rhsIx += 1
263+ else :
264+ lhsIx += 1
265+ rhsIx += 1
266+ else :
267+ lhsValuesToKeep += 1
268+ lhsIx += 1
269+
270+ try :
271+ ptrToOutDict .cast (PointerTo (None )).set (
272+ allocate_empty_const_dict (ptrToOutDict , len (rhs ) + lhsValuesToKeep )
273+ )
274+
275+ outIx = 0
276+ lhsIx = 0
277+ rhsIx = 0
278+
279+ while lhsIx < lenLhs or rhsIx < lenRhs :
280+ if rhsIx < lenRhs and lhsIx < lenLhs :
281+ lhsKey = lhs .get_key_by_index_unsafe (lhsIx )
282+ rhsKey = rhs .get_key_by_index_unsafe (rhsIx )
283+
284+ if lhsKey < rhsKey :
285+ ptrToOutDict .get ().initialize_kv_pair_unsafe (
286+ outIx ,
287+ lhsKey ,
288+ lhs .get_value_by_index_unsafe (lhsIx )
289+ )
290+ lhsIx += 1
291+ outIx += 1
292+ else :
293+ ptrToOutDict .get ().initialize_kv_pair_unsafe (
294+ outIx ,
295+ rhsKey ,
296+ rhs .get_value_by_index_unsafe (rhsIx )
297+ )
298+ rhsIx += 1
299+ outIx += 1
300+
301+ if lhsKey == rhsKey :
302+ lhsIx += 1
303+ elif lhsIx < lenLhs :
304+ lhsKey = lhs .get_key_by_index_unsafe (lhsIx )
305+
306+ ptrToOutDict .get ().initialize_kv_pair_unsafe (
307+ outIx ,
308+ lhsKey ,
309+ lhs .get_value_by_index_unsafe (lhsIx )
310+ )
311+ lhsIx += 1
312+ outIx += 1
313+ else :
314+ assert rhsIx < lenRhs
315+
316+ rhsKey = rhs .get_key_by_index_unsafe (rhsIx )
317+ ptrToOutDict .get ().initialize_kv_pair_unsafe (
318+ outIx ,
319+ rhsKey ,
320+ rhs .get_value_by_index_unsafe (rhsIx )
321+ )
322+ rhsIx += 1
323+ outIx += 1
324+
325+ assert outIx == len (rhs ) + lhsValuesToKeep
326+ ptrToOutDict .get ().set_kv_count_unsafe (outIx )
327+
328+ except : # noqa
329+ ptrToOutDict .destroy ()
330+ raise
197331
198332
199333def initialize_const_dict_from_mappable (ptrToOutDict , mappable , mayThrow ):
@@ -210,7 +344,9 @@ def initialize_const_dict_from_mappable(ptrToOutDict, mappable, mayThrow):
210344 we failed, and ptrToOutDict will not point to an initialized
211345 dictionary.
212346 """
213- initialize_empty_const_dict (ptrToOutDict , len (mappable ))
347+ ptrToOutDict .cast (PointerTo (None )).set (
348+ allocate_empty_const_dict (ptrToOutDict , len (mappable ))
349+ )
214350
215351 try :
216352 count = 0
@@ -516,6 +652,41 @@ def convert_bin_op(self, context, left, op, right, inplace):
516652 if op .matches .GtE :
517653 return context .call_py_function (const_dict_gte , (left , right ), {})
518654
655+ if op .matches .Add :
656+ newDict = context .allocateUninitializedSlot (self .typeRepresentation )
657+
658+ res = context .call_py_function (
659+ const_dict_add , (newDict .asPointer (), left , right ), {}
660+ )
661+
662+ if res is None :
663+ return res
664+
665+ context .markUninitializedSlotInitialized (newDict )
666+
667+ return newDict
668+
669+ if op .matches .Sub :
670+ right = right .convert_to_type (
671+ TupleOf (self .typeRepresentation .KeyType ), ConversionLevel .Implicit
672+ )
673+
674+ if right is None :
675+ return None
676+
677+ newDict = context .allocateUninitializedSlot (self .typeRepresentation )
678+
679+ res = context .call_py_function (
680+ const_dict_sub , (newDict .asPointer (), left , right ), {}
681+ )
682+
683+ if res is None :
684+ return res
685+
686+ context .markUninitializedSlotInitialized (newDict )
687+
688+ return newDict
689+
519690 return super ().convert_bin_op (context , left , op , right , inplace )
520691
521692 def convert_bin_op_reverse (self , context , left , op , right , inplace ):
0 commit comments