@@ -29,7 +29,8 @@ def __init__(self, arr_seq, common_shape, dtype):
2929 self .lengths = list (arr_seq ._lengths )
3030 self .next_offset = arr_seq ._get_next_offset ()
3131 self .bytes_per_buf = arr_seq ._buffer_size * MEGABYTE
32- self .dtype = dtype
32+ # Use the passed dtype only if null data array
33+ self .dtype = dtype if arr_seq ._data .size == 0 else arr_seq ._data .dtype
3334 if arr_seq .common_shape != () and common_shape != arr_seq .common_shape :
3435 raise ValueError (
3536 "All dimensions, except the first one, must match exactly" )
@@ -89,24 +90,7 @@ def __init__(self, iterable=None, buffer_size=4):
8990 self ._is_view = True
9091 return
9192
92- # If possible try pre-allocating memory.
93- try :
94- iter_len = len (iterable )
95- except TypeError :
96- pass
97- else : # We do know the iterable length
98- if iter_len == 0 :
99- return
100- first_element = np .asarray (iterable [0 ])
101- n_elements = np .sum ([len (iterable [i ])
102- for i in range (len (iterable ))])
103- new_shape = (n_elements ,) + first_element .shape [1 :]
104- self ._data = np .empty (new_shape , dtype = first_element .dtype )
105-
106- for e in iterable :
107- self .append (e , cache_build = True )
108-
109- self .finalize_append ()
93+ self .extend (iterable )
11094
11195 @property
11296 def is_array_sequence (self ):
@@ -237,18 +221,23 @@ def extend(self, elements):
237221 The shape of the elements to be added must match the one of the data of
238222 this :class:`ArraySequence` except for the first dimension.
239223 """
240- if not is_array_sequence (elements ):
241- self .extend (self .__class__ (elements ))
242- return
243- if len (elements ) == 0 :
244- return
245- self ._build_cache = _BuildCache (self ,
246- elements .common_shape ,
247- elements .data .dtype )
248- self ._resize_data_to (self ._get_next_offset () + elements .nb_elements ,
249- self ._build_cache )
250- for element in elements :
251- self .append (element )
224+ # If possible try pre-allocating memory.
225+ try :
226+ iter_len = len (elements )
227+ except TypeError :
228+ pass
229+ else : # We do know the iterable length
230+ if iter_len == 0 :
231+ return
232+ e0 = np .asarray (elements [0 ])
233+ n_elements = np .sum ([len (e ) for e in elements ])
234+ self ._build_cache = _BuildCache (self , e0 .shape [1 :], e0 .dtype )
235+ self ._resize_data_to (self ._get_next_offset () + n_elements ,
236+ self ._build_cache )
237+
238+ for e in elements :
239+ self .append (e , cache_build = True )
240+
252241 self .finalize_append ()
253242
254243 def _extend_using_coroutine (self , buffer_size = 4 ):
0 commit comments