|
4 | 4 | from .meta import Meta |
5 | 5 | from .errors import PinsInsecureReadError |
6 | 6 |
|
7 | | -from typing import Sequence |
| 7 | +from typing import Literal, Sequence, assert_never |
8 | 8 |
|
9 | 9 | # TODO: move IFileSystem out of boards, to fix circular import |
10 | 10 | # from .boards import IFileSystem |
|
15 | 15 |
|
16 | 16 |
|
17 | 17 | def _assert_is_pandas_df(x): |
18 | | - import pandas as pd |
| 18 | + df_family = _get_df_family(x) |
19 | 19 |
|
20 | | - if not isinstance(x, pd.DataFrame): |
| 20 | + if df_family != "pandas": |
21 | 21 | raise NotImplementedError( |
22 | 22 | "Currently only pandas.DataFrame can be saved to a CSV." |
23 | 23 | ) |
24 | 24 |
|
25 | 25 |
|
| 26 | +def _get_df_family(df) -> Literal["unknown", "pandas", "polars"]: |
| 27 | + try: |
| 28 | + import polars as pl |
| 29 | + except ModuleNotFoundError: |
| 30 | + is_polars_df = False |
| 31 | + else: |
| 32 | + is_polars_df = isinstance(df, pl.DataFrame) |
| 33 | + |
| 34 | + import pandas as pd |
| 35 | + |
| 36 | + is_pandas_df = isinstance(df, pd.DataFrame) |
| 37 | + |
| 38 | + if not is_polars_df and not is_pandas_df: |
| 39 | + return "unknown" |
| 40 | + if is_polars_df and is_pandas_df: # Hybrid DataFrame type! |
| 41 | + return "unknown" |
| 42 | + elif is_polars_df: |
| 43 | + return "polars" |
| 44 | + elif is_pandas_df: |
| 45 | + return "pandas" |
| 46 | + else: |
| 47 | + assert_never(df) |
| 48 | + |
| 49 | + |
26 | 50 | def load_path(meta, path_to_version): |
27 | 51 | # Check that only a single file name was given |
28 | 52 | fnames = [meta.file] if isinstance(meta.file, str) else meta.file |
@@ -174,9 +198,17 @@ def save_data( |
174 | 198 | ) |
175 | 199 |
|
176 | 200 | elif type == "parquet": |
177 | | - _assert_is_pandas_df(obj) |
178 | | - |
179 | | - obj.to_parquet(final_name) |
| 201 | + df_family = _get_df_family(obj) |
| 202 | + if df_family == "polars": |
| 203 | + obj.write_parquet(final_name) |
| 204 | + elif df_family == "pandas": |
| 205 | + obj.to_parquet(final_name) |
| 206 | + else: |
| 207 | + msg = ( |
| 208 | + "Currently only pandas.DataFrame and polars.DataFrame can be saved to " |
| 209 | + "a parquet file." |
| 210 | + ) |
| 211 | + raise NotImplementedError(msg) |
180 | 212 |
|
181 | 213 | elif type == "joblib": |
182 | 214 | import joblib |
|
0 commit comments