|
1 | 1 | from pathlib import Path |
2 | | -from typing import Sequence |
| 2 | +from typing import Literal, Sequence, assert_never |
3 | 3 |
|
4 | 4 | from .config import PINS_ENV_INSECURE_READ, get_allow_pickle_read |
5 | 5 | from .errors import PinsInsecureReadError |
|
14 | 14 |
|
15 | 15 |
|
16 | 16 | def _assert_is_pandas_df(x, file_type: str) -> None: |
17 | | - import pandas as pd |
| 17 | + df_family = _get_df_family(x) |
18 | 18 |
|
19 | | - if not isinstance(x, pd.DataFrame): |
| 19 | + if df_family != "pandas": |
20 | 20 | raise NotImplementedError( |
21 | 21 | f"Currently only pandas.DataFrame can be saved as type {file_type!r}." |
22 | 22 | ) |
23 | 23 |
|
24 | 24 |
|
| 25 | +def _get_df_family(df) -> Literal["unknown", "pandas", "polars"]: |
| 26 | + try: |
| 27 | + import polars as pl |
| 28 | + except ModuleNotFoundError: |
| 29 | + is_polars_df = False |
| 30 | + else: |
| 31 | + is_polars_df = isinstance(df, pl.DataFrame) |
| 32 | + |
| 33 | + import pandas as pd |
| 34 | + |
| 35 | + is_pandas_df = isinstance(df, pd.DataFrame) |
| 36 | + |
| 37 | + if not is_polars_df and not is_pandas_df: |
| 38 | + return "unknown" |
| 39 | + if is_polars_df and is_pandas_df: # Hybrid DataFrame type! |
| 40 | + return "unknown" |
| 41 | + elif is_polars_df: |
| 42 | + return "polars" |
| 43 | + elif is_pandas_df: |
| 44 | + return "pandas" |
| 45 | + else: |
| 46 | + assert_never(df) |
| 47 | + |
| 48 | + |
25 | 49 | def load_path(meta, path_to_version): |
26 | 50 | # Check that only a single file name was given |
27 | 51 | fnames = [meta.file] if isinstance(meta.file, str) else meta.file |
@@ -171,9 +195,17 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen |
171 | 195 | ) |
172 | 196 |
|
173 | 197 | elif type == "parquet": |
174 | | - _assert_is_pandas_df(obj, file_type=type) |
175 | | - |
176 | | - obj.to_parquet(final_name) |
| 198 | + df_family = _get_df_family(obj) |
| 199 | + if df_family == "polars": |
| 200 | + obj.write_parquet(final_name) |
| 201 | + elif df_family == "pandas": |
| 202 | + obj.to_parquet(final_name) |
| 203 | + else: |
| 204 | + msg = ( |
| 205 | + "Currently only pandas.DataFrame and polars.DataFrame can be saved to " |
| 206 | + "a parquet file." |
| 207 | + ) |
| 208 | + raise NotImplementedError(msg) |
177 | 209 |
|
178 | 210 | elif type == "joblib": |
179 | 211 | import joblib |
|
0 commit comments