Skip to content

Commit 85bd8c6

Browse files
Write more robust df-library choosing logic.
1 parent 036c7e9 commit 85bd8c6

File tree

1 file changed

+82
-49
lines changed

1 file changed

+82
-49
lines changed

pins/drivers.py

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Literal, Sequence
2+
from typing import Literal, Sequence, TypeAlias
33

44
from .config import PINS_ENV_INSECURE_READ, get_allow_pickle_read
55
from .errors import PinsInsecureReadError
@@ -11,39 +11,7 @@
1111

1212
UNSAFE_TYPES = frozenset(["joblib"])
1313
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])
14-
15-
16-
def _assert_is_pandas_df(x, file_type: str) -> None:
17-
df_family = _get_df_family(x)
18-
19-
if df_family != "pandas":
20-
raise NotImplementedError(
21-
f"Currently only pandas.DataFrame can be saved as type {file_type!r}."
22-
)
23-
24-
25-
def _get_df_family(df) -> Literal["pandas", "polars"]:
26-
"""Return the type of DataFrame, or raise NotImplementedError if we can't decide."""
27-
try:
28-
import polars as pl
29-
except ModuleNotFoundError:
30-
is_polars_df = False
31-
else:
32-
is_polars_df = isinstance(df, pl.DataFrame)
33-
34-
import pandas as pd
35-
36-
is_pandas_df = isinstance(df, pd.DataFrame)
37-
38-
if is_polars_df and is_pandas_df:
39-
raise NotImplementedError(
40-
"Hybrid DataFrames (simultaneously pandas and polars) are not supported."
41-
)
42-
elif is_polars_df:
43-
return "polars"
44-
elif is_pandas_df:
45-
return "pandas"
46-
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
14+
_DFLib: TypeAlias = Literal["pandas", "polars"]
4715

4816

4917
def load_path(meta, path_to_version):
@@ -176,36 +144,31 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
176144
final_name = f"{fname}{suffix}"
177145

178146
if type == "csv":
179-
_assert_is_pandas_df(obj, file_type=type)
180-
147+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
181148
obj.to_csv(final_name, index=False)
182149

183150
elif type == "arrow":
184151
# NOTE: R pins accepts the type arrow, and saves it as feather.
185152
# we allow reading this type, but raise an error for writing.
186-
_assert_is_pandas_df(obj, file_type=type)
187-
153+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
188154
obj.to_feather(final_name)
189155

190156
elif type == "feather":
191-
_assert_is_pandas_df(obj, file_type=type)
157+
_choose_df_lib(obj, supported_libs=["pandas"], file_type=type)
192158

193159
raise NotImplementedError(
194160
'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
195161
)
196162

197163
elif type == "parquet":
198-
df_family = _get_df_family(obj)
199-
if df_family == "polars":
200-
obj.write_parquet(final_name)
201-
elif df_family == "pandas":
164+
df_lib = _choose_df_lib(obj, supported_libs=["pandas", "polars"], file_type=type)
165+
166+
if df_lib == "pandas":
202167
obj.to_parquet(final_name)
168+
elif df_lib == "polars":
169+
obj.write_parquet(final_name)
203170
else:
204-
msg = (
205-
"Currently only pandas.DataFrame and polars.DataFrame can be saved to "
206-
"a parquet file."
207-
)
208-
raise NotImplementedError(msg)
171+
raise NotImplementedError
209172

210173
elif type == "joblib":
211174
import joblib
@@ -233,7 +196,7 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
233196

234197
def default_title(obj, name):
235198
try:
236-
_get_df_family(obj)
199+
_choose_df_lib(obj)
237200
except NotImplementedError:
238201
obj_name = type(obj).__qualname__
239202
return f"{name}: a pinned {obj_name} object"
@@ -242,3 +205,73 @@ def default_title(obj, name):
242205
# see https://github.com/machow/pins-python/issues/5
243206
shape_str = " x ".join(map(str, obj.shape))
244207
return f"{name}: a pinned {shape_str} DataFrame"
208+
209+
210+
def _choose_df_lib(
211+
df,
212+
*,
213+
supported_libs: list[_DFLib] = ["pandas", "polars"],
214+
file_type: str | None = None,
215+
) -> _DFLib:
216+
"""Return the type of DataFrame library used in the given DataFrame.
217+
218+
Args:
219+
df:
220+
The object to check - might not be a DataFrame necessarily.
221+
supported_libs:
222+
The DataFrame libraries to accept for this df.
223+
file_type:
224+
The file type we're trying to save to - used to give more specific error messages.
225+
226+
Raises:
227+
NotImplementedError: If the DataFrame type is not recognized.
228+
"""
229+
df_libs: list[_DFLib] = []
230+
231+
# pandas
232+
import pandas as pd
233+
234+
if isinstance(df, pd.DataFrame):
235+
df_libs.append("pandas")
236+
237+
# polars
238+
try:
239+
import polars as pl
240+
except ModuleNotFoundError:
241+
pass
242+
else:
243+
if isinstance(df, pl.DataFrame):
244+
df_libs.append("polars")
245+
246+
if len(df_libs) == 1:
247+
(df_lib,) = df_libs
248+
elif len(df_libs) > 1:
249+
msg = (
250+
f"Hybrid DataFrames are not supported: "
251+
f"should only be one of {supported_libs!r}, "
252+
f"but got an object from multiple libraries {df_libs!r}."
253+
)
254+
raise NotImplementedError(msg)
255+
else:
256+
raise NotImplementedError(f"Unrecognized DataFrame type: {type(df)}")
257+
258+
if df_lib not in supported_libs:
259+
if file_type is None:
260+
ftype_clause = "in pins"
261+
else:
262+
ftype_clause = f"for type {file_type!r}"
263+
264+
if len(supported_libs) == 1:
265+
msg = (
266+
f"Currently only {supported_libs[0]} DataFrames can be saved "
267+
f"{ftype_clause}. {df_lib} DataFrames are not yet supported."
268+
)
269+
else:
270+
msg = (
271+
f"Currently only DataFrames from the following libraries can be saved "
272+
f"{ftype_clause}: {supported_libs!r}."
273+
)
274+
275+
raise NotImplementedError(msg)
276+
277+
return df_lib

0 commit comments

Comments
 (0)