Skip to content

Commit 9667fc0

Browse files
Support writing polars.DataFrame to parquet.
1 parent 1a1536d commit 9667fc0

File tree

2 files changed

+68
-6
lines changed

2 files changed

+68
-6
lines changed

pins/drivers.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Sequence
2+
from typing import Literal, Sequence, assert_never
33

44
from .config import PINS_ENV_INSECURE_READ, get_allow_pickle_read
55
from .errors import PinsInsecureReadError
@@ -14,14 +14,38 @@
1414

1515

1616
def _assert_is_pandas_df(x, file_type: str) -> None:
    """Raise ``NotImplementedError`` unless ``x`` is classified as a pandas DataFrame.

    ``file_type`` is only used to build the error message.
    """
    if _get_df_family(x) == "pandas":
        return

    raise NotImplementedError(
        f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
    )
2323

2424

25+
def _get_df_family(df) -> Literal["unknown", "pandas", "polars"]:
26+
try:
27+
import polars as pl
28+
except ModuleNotFoundError:
29+
is_polars_df = False
30+
else:
31+
is_polars_df = isinstance(df, pl.DataFrame)
32+
33+
import pandas as pd
34+
35+
is_pandas_df = isinstance(df, pd.DataFrame)
36+
37+
if not is_polars_df and not is_pandas_df:
38+
return "unknown"
39+
if is_polars_df and is_pandas_df: # Hybrid DataFrame type!
40+
return "unknown"
41+
elif is_polars_df:
42+
return "polars"
43+
elif is_pandas_df:
44+
return "pandas"
45+
else:
46+
assert_never(df)
47+
48+
2549
def load_path(meta, path_to_version):
2650
# Check that only a single file name was given
2751
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
@@ -171,9 +195,17 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
171195
)
172196

173197
elif type == "parquet":
174-
_assert_is_pandas_df(obj, file_type=type)
175-
176-
obj.to_parquet(final_name)
198+
df_family = _get_df_family(obj)
199+
if df_family == "polars":
200+
obj.write_parquet(final_name)
201+
elif df_family == "pandas":
202+
obj.to_parquet(final_name)
203+
else:
204+
msg = (
205+
"Currently only pandas.DataFrame and polars.DataFrame can be saved to "
206+
"a parquet file."
207+
)
208+
raise NotImplementedError(msg)
177209

178210
elif type == "joblib":
179211
import joblib

pins/tests/test_drivers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,36 @@ def test_driver_roundtrip(tmp_path: Path, type_):
7676
assert df.equals(obj)
7777

7878

79+
@pytest.mark.parametrize("type_", ["parquet"])
def test_driver_polars_roundtrip(tmp_dir2, type_):
    """Save a polars DataFrame with ``save_data`` and check it survives a reload."""
    import polars as pl

    original = pl.DataFrame({"x": [1, 2, 3]})

    stem = "some_df"
    expected_name = f"{stem}.{type_}"

    # The path handed to save_data has no suffix; the returned name should
    # carry the one matching the file type.
    saved_path = save_data(original, tmp_dir2 / stem, type_)
    assert Path(saved_path).name == expected_name

    meta = MetaRaw(expected_name, type_, "my_pin")
    local_fs = fsspec.filesystem("file")
    loaded = load_data(meta, local_fs, tmp_dir2, allow_pickle_read=True)

    # The loaded frame is converted from pandas back to polars before the
    # comparison.
    assert original.equals(pl.DataFrame(loaded))
107+
108+
79109
@pytest.mark.parametrize(
80110
"type_",
81111
[

0 commit comments

Comments
 (0)