|
1 | 1 | from __future__ import annotations |
2 | 2 | from pydantic import BaseModel |
3 | 3 | from typing import ( |
4 | | - Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional |
| 4 | + Tuple, Callable, Union, List, TypeVar, Generic, Dict, Optional, Iterable |
5 | 5 | ) |
6 | 6 | from pathlib import Path |
7 | 7 | from functools import lru_cache |
@@ -89,9 +89,16 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]: |
89 | 89 | get_item=lambda df, index: df.iloc[index], |
90 | 90 | ) |
91 | 91 |
|
92 | | - def __getitem__(self: Dataset[T], index: int) -> T: |
93 | | - '''Get an example ``T`` from the ``Dataset[T]``''' |
94 | | - return self.get_item(self.dataframe, index) |
| 92 | + def __getitem__( |
| 93 | + self: Dataset[T], |
| 94 | + select: Union[int, slice, Iterable, Callable[[pd.DataFrame], Iterable[int]]] |
| 95 | + ) -> Union[T, Dataset[T]]: |
| 96 | + '''Get selection from the ``Dataset[T]``''' |
| 97 | + if np.issubdtype(type(select), np.integer): |
| 98 | + return self.get_item(self.dataframe, select) |
| 99 | + else: |
| 100 | + dataframe = self.dataframe.iloc[select] |
| 101 | + return self.replace(dataframe=dataframe, length=len(dataframe)) |
95 | 102 |
|
96 | 103 | def __len__(self): |
97 | 104 | return self.length |
@@ -198,27 +205,8 @@ def subset( |
198 | 205 | ... )[-1] |
199 | 206 | 2 |
200 | 207 | ''' |
201 | | - |
202 | | - mask = mask_fn(self.dataframe) |
203 | | - if isinstance(mask, list): |
204 | | - mask = np.array(mask) |
205 | | - elif isinstance(mask, pd.Series): |
206 | | - mask = mask.values |
207 | | - |
208 | | - if len(mask.shape) != 1: |
209 | | - raise AssertionError('Expected single dimension in mask') |
210 | | - |
211 | | - if len(mask) != len(self): |
212 | | - raise AssertionError( |
213 | | - 'Expected mask to have the same length as the dataset' |
214 | | - ) |
215 | | - |
216 | | - indices = np.argwhere(mask).squeeze(1) |
217 | | - return Dataset( |
218 | | - dataframe=self.dataframe.iloc[indices], |
219 | | - length=len(indices), |
220 | | - get_item=self.get_item, |
221 | | - ) |
| 208 | + dataframe = self.dataframe[mask_fn(self.dataframe)] |
| 209 | + return self.replace(dataframe=dataframe, length=len(dataframe)) |
222 | 210 |
|
223 | 211 | def split( |
224 | 212 | self, |
|
0 commit comments