|
4 | 4 | from .meta import Meta |
5 | 5 | from .errors import PinsInsecureReadError |
6 | 6 |
|
7 | | -from typing import Sequence |
| 7 | +from typing import Literal, Sequence, assert_never |
8 | 8 |
|
9 | 9 | # TODO: move IFileSystem out of boards, to fix circular import |
10 | 10 | # from .boards import IFileSystem |
|
15 | 15 |
|
16 | 16 |
|
def _assert_is_pandas_df(x, file_type: str) -> None:
    """Raise unless ``x`` is classified as a pandas DataFrame.

    Parameters
    ----------
    x:
        Object to check; it is classified by ``_get_df_family``.
    file_type:
        Name of the target file type, used only in the error message.

    Raises
    ------
    NotImplementedError
        If ``_get_df_family(x)`` returns anything other than "pandas".
    """
    # Guard clause: the supported case exits early, everything else is an error.
    if _get_df_family(x) == "pandas":
        return

    raise NotImplementedError(
        f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
    )
24 | 24 |
|
25 | 25 |
|
def _get_df_family(df) -> Literal["unknown", "pandas", "polars"]:
    """Classify ``df`` by the DataFrame library it is an instance of.

    polars is treated as optional: when it cannot be imported, ``df`` is
    considered not to be a polars DataFrame. pandas is imported
    unconditionally.

    Parameters
    ----------
    df:
        Any object.

    Returns
    -------
    "pandas" or "polars" when ``df`` is an instance of exactly one of
    ``pandas.DataFrame`` / ``polars.DataFrame``; "unknown" when it is an
    instance of neither, or of both at once.
    """
    try:
        import polars as pl
    except ModuleNotFoundError:
        is_polars_df = False
    else:
        is_polars_df = isinstance(df, pl.DataFrame)

    import pandas as pd

    is_pandas_df = isinstance(df, pd.DataFrame)

    # A hybrid object (instance of both classes) is ambiguous, so it is
    # reported as "unknown" rather than picking one family arbitrarily.
    if is_polars_df and is_pandas_df:
        return "unknown"
    if is_polars_df:
        return "polars"
    if is_pandas_df:
        return "pandas"
    # Neither flag is set. The four boolean combinations are now exhausted,
    # so no further fallback branch is needed.
    return "unknown"
| 48 | + |
| 49 | + |
26 | 50 | def load_path(meta, path_to_version): |
27 | 51 | # Check that only a single file name was given |
28 | 52 | fnames = [meta.file] if isinstance(meta.file, str) else meta.file |
@@ -172,9 +196,17 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen |
172 | 196 | ) |
173 | 197 |
|
174 | 198 | elif type == "parquet": |
175 | | - _assert_is_pandas_df(obj, file_type=type) |
176 | | - |
177 | | - obj.to_parquet(final_name) |
| 199 | + df_family = _get_df_family(obj) |
| 200 | + if df_family == "polars": |
| 201 | + obj.write_parquet(final_name) |
| 202 | + elif df_family == "pandas": |
| 203 | + obj.to_parquet(final_name) |
| 204 | + else: |
| 205 | + msg = ( |
| 206 | + "Currently only pandas.DataFrame and polars.DataFrame can be saved to " |
| 207 | + "a parquet file." |
| 208 | + ) |
| 209 | + raise NotImplementedError(msg) |
178 | 210 |
|
179 | 211 | elif type == "joblib": |
180 | 212 | import joblib |
|
0 commit comments