2828 is_allowed_extension_array_dtype ,
2929 is_duck_array ,
3030 is_duck_dask_array ,
31+ is_full_slice ,
3132 is_scalar ,
3233 is_valid_numpy_dtype ,
3334 to_0d_array ,
4344 from xarray .namedarray ._typing import _Shape , duckarray
4445 from xarray .namedarray .parallelcompat import ChunkManagerEntrypoint
4546
47+ BasicIndexerType = int | np .integer | slice
48+ OuterIndexerType = BasicIndexerType | np .ndarray [Any , np .dtype [np .integer ]]
49+
4650
4751@dataclass
4852class IndexSelResult :
@@ -300,19 +304,83 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
300304 return slice (start , stop , step )
301305
302306
303- def _index_indexer_1d (old_indexer , applied_indexer , size : int ):
304- if isinstance (applied_indexer , slice ) and applied_indexer == slice (None ):
307+ def normalize_array (
308+ array : np .ndarray [Any , np .dtype [np .integer ]], size : int
309+ ) -> np .ndarray [Any , np .dtype [np .integer ]]:
310+ """
311+ Ensure that the given array only contains positive values.
312+
313+ Examples
314+ --------
315+ >>> normalize_array(np.array([-1, -2, -3, -4]), 10)
316+ array([9, 8, 7, 6])
317+ >>> normalize_array(np.array([-5, 3, 5, -1, 8]), 12)
318+ array([ 7, 3, 5, 11, 8])
319+ """
320+ if np .issubdtype (array .dtype , np .unsignedinteger ):
321+ return array
322+
323+ return np .where (array >= 0 , array , array + size )
324+
325+
326+ def slice_slice_by_array (
327+ old_slice : slice ,
328+ array : np .ndarray [Any , np .dtype [np .integer ]],
329+ size : int ,
330+ ) -> np .ndarray [Any , np .dtype [np .integer ]]:
331+ """Given a slice and the size of the dimension to which it will be applied,
332+ index it with an array to return a new array equivalent to applying
333+ the slices sequentially
334+
335+ Examples
336+ --------
337+ >>> slice_slice_by_array(slice(2, 10), np.array([1, 3, 5]), 12)
338+ array([3, 5, 7])
339+ >>> slice_slice_by_array(slice(1, None, 2), np.array([1, 3, 7, 8]), 20)
340+ array([ 3, 7, 15, 17])
341+ >>> slice_slice_by_array(slice(None, None, -1), np.array([2, 4, 7]), 20)
342+ array([17, 15, 12])
343+ """
344+ # to get a concrete slice, limited to the size of the array
345+ normalized_slice = normalize_slice (old_slice , size )
346+
347+ size_after_slice = len (range (* normalized_slice .indices (size )))
348+ normalized_array = normalize_array (array , size_after_slice )
349+
350+ new_indexer = normalized_array * normalized_slice .step + normalized_slice .start
351+
352+ if np .any (new_indexer >= size ):
353+ raise IndexError ("indices out of bounds" ) # TODO: more helpful error message
354+
355+ return new_indexer
356+
357+
358+ def _index_indexer_1d (
359+ old_indexer : OuterIndexerType ,
360+ applied_indexer : OuterIndexerType ,
361+ size : int ,
362+ ) -> OuterIndexerType :
363+ if is_full_slice (applied_indexer ):
305364 # shortcut for the usual case
306365 return old_indexer
366+ if is_full_slice (old_indexer ):
367+ # shortcut for full slices
368+ return applied_indexer
369+
370+ indexer : OuterIndexerType
307371 if isinstance (old_indexer , slice ):
308372 if isinstance (applied_indexer , slice ):
309373 indexer = slice_slice (old_indexer , applied_indexer , size )
310374 elif isinstance (applied_indexer , integer_types ):
311- indexer = range (* old_indexer .indices (size ))[applied_indexer ] # type: ignore[assignment]
375+ indexer = range (* old_indexer .indices (size ))[applied_indexer ]
312376 else :
313- indexer = _expand_slice (old_indexer , size )[ applied_indexer ]
314- else :
377+ indexer = slice_slice_by_array (old_indexer , applied_indexer , size )
378+ elif isinstance ( old_indexer , np . ndarray ) :
315379 indexer = old_indexer [applied_indexer ]
380+ else :
381+ # should be unreachable
382+ raise ValueError ("cannot index integers. Please open an issuec-" )
383+
316384 return indexer
317385
318386
@@ -389,7 +457,7 @@ class BasicIndexer(ExplicitIndexer):
389457
390458 __slots__ = ()
391459
392- def __init__ (self , key : tuple [int | np . integer | slice , ...]):
460+ def __init__ (self , key : tuple [BasicIndexerType , ...]):
393461 if not isinstance (key , tuple ):
394462 raise TypeError (f"key must be a tuple: { key !r} " )
395463
@@ -421,9 +489,7 @@ class OuterIndexer(ExplicitIndexer):
421489
422490 def __init__ (
423491 self ,
424- key : tuple [
425- int | np .integer | slice | np .ndarray [Any , np .dtype [np .generic ]], ...
426- ],
492+ key : tuple [BasicIndexerType | np .ndarray [Any , np .dtype [np .generic ]], ...],
427493 ):
428494 if not isinstance (key , tuple ):
429495 raise TypeError (f"key must be a tuple: { key !r} " )
@@ -629,7 +695,8 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
629695
630696 def _updated_key (self , new_key : ExplicitIndexer ) -> BasicIndexer | OuterIndexer :
631697 iter_new_key = iter (expanded_indexer (new_key .tuple , self .ndim ))
632- full_key = []
698+
699+ full_key : list [OuterIndexerType ] = []
633700 for size , k in zip (self .array .shape , self .key .tuple , strict = True ):
634701 if isinstance (k , integer_types ):
635702 full_key .append (k )
@@ -638,7 +705,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
638705 full_key_tuple = tuple (full_key )
639706
640707 if all (isinstance (k , integer_types + (slice ,)) for k in full_key_tuple ):
641- return BasicIndexer (full_key_tuple )
708+ return BasicIndexer (cast ( tuple [ BasicIndexerType , ...], full_key_tuple ) )
642709 return OuterIndexer (full_key_tuple )
643710
644711 @property
0 commit comments