Skip to content

Commit 262901b

Browse files
Support writing polars.DataFrame to parquet.
1 parent 6cbc1b6 commit 262901b

File tree

2 files changed

+68
-6
lines changed

2 files changed

+68
-6
lines changed

pins/drivers.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .meta import Meta
55
from .errors import PinsInsecureReadError
66

7-
from typing import Sequence
7+
from typing import Literal, Sequence, assert_never
88

99
# TODO: move IFileSystem out of boards, to fix circular import
1010
# from .boards import IFileSystem
@@ -15,14 +15,38 @@
1515

1616

1717
# Guard for writers that only support pandas: raises NotImplementedError
# unless `x` is classified as a pandas DataFrame. `file_type` is used only to
# build the error message.
# In this diff view the "-" rows are the removed direct isinstance check and
# the "+" rows are the replacement that delegates to _get_df_family.
def _assert_is_pandas_df(x, file_type: str) -> None:
18-
import pandas as pd
18+
df_family = _get_df_family(x)
1919

20-
if not isinstance(x, pd.DataFrame):
20+
if df_family != "pandas":
2121
raise NotImplementedError(
2222
f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
2323
)
2424

2525

26+
def _get_df_family(df) -> Literal["unknown", "pandas", "polars"]:
    """Classify ``df`` by the DataFrame library it belongs to.

    Parameters
    ----------
    df:
        Any object.

    Returns
    -------
    "pandas" if ``df`` is a ``pandas.DataFrame``, "polars" if it is a
    ``polars.DataFrame``, and "unknown" if it is neither or, for a hybrid
    type, both.
    """
    try:
        import polars as pl
    except ModuleNotFoundError:
        # polars is optional; without it nothing can be a polars DataFrame.
        is_polars_df = False
    else:
        is_polars_df = isinstance(df, pl.DataFrame)

    import pandas as pd

    is_pandas_df = isinstance(df, pd.DataFrame)

    # Equal flags mean "neither" or "both" (a hybrid type); the caller cannot
    # pick a single writer for either case.
    if is_polars_df == is_pandas_df:
        return "unknown"

    # Exactly one flag is set here, so no further fallback branch is needed.
    return "polars" if is_polars_df else "pandas"
48+
49+
2650
def load_path(meta, path_to_version):
2751
# Check that only a single file name was given
2852
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
@@ -172,9 +196,17 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
172196
)
173197

174198
elif type == "parquet":
175-
_assert_is_pandas_df(obj, file_type=type)
176-
177-
obj.to_parquet(final_name)
199+
df_family = _get_df_family(obj)
200+
if df_family == "polars":
201+
obj.write_parquet(final_name)
202+
elif df_family == "pandas":
203+
obj.to_parquet(final_name)
204+
else:
205+
msg = (
206+
"Currently only pandas.DataFrame and polars.DataFrame can be saved to "
207+
"a parquet file."
208+
)
209+
raise NotImplementedError(msg)
178210

179211
elif type == "joblib":
180212
import joblib

pins/tests/test_drivers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,36 @@ def test_driver_roundtrip(tmp_path: Path, type_):
7777
assert df.equals(obj)
7878

7979

80+
@pytest.mark.parametrize(
    "type_",
    [
        "parquet",
    ],
)
def test_driver_polars_roundtrip(tmp_dir2, type_):
    """Save a polars DataFrame with ``save_data`` and read it back.

    The file is written under ``tmp_dir2``, loaded again through
    ``load_data``, converted back to polars and compared with the input.
    """
    # polars is an optional dependency: skip instead of erroring with
    # ModuleNotFoundError when it is not installed.
    pl = pytest.importorskip("polars")

    df = pl.DataFrame({"x": [1, 2, 3]})

    fname = "some_df"
    full_file = f"{fname}.{type_}"

    p_obj = tmp_dir2 / fname
    res_fname = save_data(df, p_obj, type_)

    # save_data is expected to append the type as the file suffix.
    assert Path(res_fname).name == full_file

    meta = MetaRaw(full_file, type_, "my_pin")
    pandas_df = load_data(
        meta, fsspec.filesystem("file"), tmp_dir2, allow_pickle_read=True
    )

    # Convert from pandas to polars
    obj = pl.DataFrame(pandas_df)

    assert df.equals(obj)
108+
109+
80110
@pytest.mark.parametrize(
81111
"type_",
82112
[

0 commit comments

Comments
 (0)