@@ -59,6 +59,9 @@ def __getitem__(self, key):
5959 # Key was not a valid index/slice after all.
6060 return self .store [key ] # Will raise the proper error.
6161
62+ def __contains__ (self , key ):
63+ return key in self .store
64+
6265 def __delitem__ (self , key ):
6366 del self .store [key ]
6467
@@ -90,7 +93,7 @@ class PerArrayDict(SliceableDataDict):
9093 Positional and keyword arguments, passed straight through the ``dict``
9194 constructor.
9295 """
93- def __init__ (self , n_rows = None , * args , ** kwargs ):
96+ def __init__ (self , n_rows = 0 , * args , ** kwargs ):
9497 self .n_rows = n_rows
9598 super (PerArrayDict , self ).__init__ (* args , ** kwargs )
9699
@@ -105,13 +108,17 @@ def __setitem__(self, key, value):
105108 raise ValueError ("data_per_streamline must be a 2D array." )
106109
107110 # We make sure there is the right amount of values
108- if self .n_rows is not None and len (value ) != self .n_rows :
111+ if self .n_rows > 0 and len (value ) != self .n_rows :
109112 msg = ("The number of values ({0}) should match n_elements "
110113 "({1})." ).format (len (value ), self .n_rows )
111114 raise ValueError (msg )
112115
113116 self .store [key ] = value
114117
118+ def _extend_entry (self , key , value ):
119+ """ Appends the `value` to the entry specified by `key`. """
120+ self [key ] = np .concatenate ([self [key ], value ])
121+
115122 def extend (self , other ):
116123 """ Appends the elements of another :class:`PerArrayDict`.
117124
@@ -131,16 +138,20 @@ def extend(self, other):
131138 -----
132139 The keys in both dictionaries must be the same.
133140 """
134- if sorted (self .keys ()) != sorted (other .keys ()):
141+ if (len (self ) > 0 and len (other ) > 0
142+ and sorted (self .keys ()) != sorted (other .keys ())):
135143 msg = ("Entry mismatched between the two PerArrayDict objects."
136144 " This PerArrayDict contains '{0}' whereas the other "
137145 " contains '{1}'." ).format (sorted (self .keys ()),
138146 sorted (other .keys ()))
139147 raise ValueError (msg )
140148
141149 self .n_rows += other .n_rows
142- for key in self .keys ():
143- self [key ] = np .concatenate ([self [key ], other [key ]])
150+ for key in other .keys ():
151+ if key not in self :
152+ self [key ] = other [key ]
153+ else :
154+ self ._extend_entry (key , other [key ])
144155
145156
146157class PerArraySequenceDict (PerArrayDict ):
@@ -158,43 +169,16 @@ def __setitem__(self, key, value):
158169 value = ArraySequence (value )
159170
160171 # We make sure there is the right amount of data.
161- if (self .n_rows is not None and
162- value .total_nb_rows != self .n_rows ):
172+ if self .n_rows > 0 and value .total_nb_rows != self .n_rows :
163173 msg = ("The number of values ({0}) should match "
164174 "({1})." ).format (value .total_nb_rows , self .n_rows )
165175 raise ValueError (msg )
166176
167177 self .store [key ] = value
168178
169- def extend (self , other ):
170- """ Appends the elements of another :class:`PerArraySequenceDict`.
171-
172- That is, for each entry in this dictionary, we append the elements
173- coming from the other dictionary at the corresponding entry.
174-
175- Parameters
176- ----------
177- other : :class:`PerArraySequenceDict` object
178- Its data will be appended to the data of this dictionary.
179-
180- Returns
181- -------
182- None
183-
184- Notes
185- -----
186- The keys in both dictionaries must be the same.
187- """
188- if sorted (self .keys ()) != sorted (other .keys ()):
189- msg = ("Key mismatched between the two PerArrayDict objects."
190- " This PerArrayDict contains '{0}' whereas the other "
191- " contains '{1}'." ).format (sorted (self .keys ()),
192- sorted (other .keys ()))
193- raise ValueError (msg )
194-
195- self .n_rows += other .n_rows
196- for key in self .keys ():
197- self [key ].extend (other [key ])
179+ def _extend_entry (self , key , value ):
180+ """ Appends the `value` to the entry specified by `key`. """
181+ self [key ].extend (value )
198182
199183
200184class LazyDict (collections .MutableMapping ):
0 commit comments