11from pathlib import Path
2- from typing import Literal , Sequence
2+ from typing import Literal , Sequence , TypeAlias
33
44from .config import PINS_ENV_INSECURE_READ , get_allow_pickle_read
55from .errors import PinsInsecureReadError
1111
1212UNSAFE_TYPES = frozenset (["joblib" ])
1313REQUIRES_SINGLE_FILE = frozenset (["csv" , "joblib" , "file" ])
14-
15-
16- def _assert_is_pandas_df (x , file_type : str ) -> None :
17- df_family = _get_df_family (x )
18-
19- if df_family != "pandas" :
20- raise NotImplementedError (
21- f"Currently only pandas.DataFrame can be saved as type { file_type !r} ."
22- )
23-
24-
25- def _get_df_family (df ) -> Literal ["pandas" , "polars" ]:
26- """Return the type of DataFrame, or raise NotImplementedError if we can't decide."""
27- try :
28- import polars as pl
29- except ModuleNotFoundError :
30- is_polars_df = False
31- else :
32- is_polars_df = isinstance (df , pl .DataFrame )
33-
34- import pandas as pd
35-
36- is_pandas_df = isinstance (df , pd .DataFrame )
37-
38- if is_polars_df and is_pandas_df :
39- raise NotImplementedError (
40- "Hybrid DataFrames (simultaneously pandas and polars) are not supported."
41- )
42- elif is_polars_df :
43- return "polars"
44- elif is_pandas_df :
45- return "pandas"
46- raise NotImplementedError (f"Unrecognized DataFrame type: { type (df )} " )
14+ _DFLib : TypeAlias = Literal ["pandas" , "polars" ]
4715
4816
4917def load_path (meta , path_to_version ):
@@ -176,36 +144,31 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
176144 final_name = f"{ fname } { suffix } "
177145
178146 if type == "csv" :
179- _assert_is_pandas_df (obj , file_type = type )
180-
147+ _choose_df_lib (obj , supported_libs = ["pandas" ], file_type = type )
181148 obj .to_csv (final_name , index = False )
182149
183150 elif type == "arrow" :
184151 # NOTE: R pins accepts the type arrow, and saves it as feather.
185152 # we allow reading this type, but raise an error for writing.
186- _assert_is_pandas_df (obj , file_type = type )
187-
153+ _choose_df_lib (obj , supported_libs = ["pandas" ], file_type = type )
188154 obj .to_feather (final_name )
189155
190156 elif type == "feather" :
191- _assert_is_pandas_df (obj , file_type = type )
157+ _choose_df_lib (obj , supported_libs = [ "pandas" ] , file_type = type )
192158
193159 raise NotImplementedError (
194160 'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
195161 )
196162
197163 elif type == "parquet" :
198- df_family = _get_df_family (obj )
199- if df_family == "polars" :
200- obj .write_parquet (final_name )
201- elif df_family == "pandas" :
164+ df_lib = _choose_df_lib (obj , supported_libs = ["pandas" , "polars" ], file_type = type )
165+
166+ if df_lib == "pandas" :
202167 obj .to_parquet (final_name )
168+ elif df_lib == "polars" :
169+ obj .write_parquet (final_name )
203170 else :
204- msg = (
205- "Currently only pandas.DataFrame and polars.DataFrame can be saved to "
206- "a parquet file."
207- )
208- raise NotImplementedError (msg )
171+ raise NotImplementedError
209172
210173 elif type == "joblib" :
211174 import joblib
@@ -233,7 +196,7 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
233196
234197def default_title (obj , name ):
235198 try :
236- _get_df_family (obj )
199+ _choose_df_lib (obj )
237200 except NotImplementedError :
238201 obj_name = type (obj ).__qualname__
239202 return f"{ name } : a pinned { obj_name } object"
@@ -242,3 +205,73 @@ def default_title(obj, name):
242205 # see https://github.com/machow/pins-python/issues/5
243206 shape_str = " x " .join (map (str , obj .shape ))
244207 return f"{ name } : a pinned { shape_str } DataFrame"
208+
209+
210+ def _choose_df_lib (
211+ df ,
212+ * ,
213+ supported_libs : list [_DFLib ] = ["pandas" , "polars" ],
214+ file_type : str | None = None ,
215+ ) -> _DFLib :
216+ """Return the type of DataFrame library used in the given DataFrame.
217+
218+ Args:
219+ df:
220+ The object to check - might not be a DataFrame necessarily.
221+ supported_libs:
222+ The DataFrame libraries to accept for this df.
223+ file_type:
224+ The file type we're trying to save to - used to give more specific error messages.
225+
226+ Raises:
227+ NotImplementedError: If the DataFrame type is not recognized.
228+ """
229+ df_libs : list [_DFLib ] = []
230+
231+ # pandas
232+ import pandas as pd
233+
234+ if isinstance (df , pd .DataFrame ):
235+ df_libs .append ("pandas" )
236+
237+ # polars
238+ try :
239+ import polars as pl
240+ except ModuleNotFoundError :
241+ pass
242+ else :
243+ if isinstance (df , pl .DataFrame ):
244+ df_libs .append ("polars" )
245+
246+ if len (df_libs ) == 1 :
247+ (df_lib ,) = df_libs
248+ elif len (df_libs ) > 1 :
249+ msg = (
250+ f"Hybrid DataFrames are not supported: "
251+ f"should only be one of { supported_libs !r} , "
252+ f"but got an object from multiple libraries { df_libs !r} ."
253+ )
254+ raise NotImplementedError (msg )
255+ else :
256+ raise NotImplementedError (f"Unrecognized DataFrame type: { type (df )} " )
257+
258+ if df_lib not in supported_libs :
259+ if file_type is None :
260+ ftype_clause = "in pins"
261+ else :
262+ ftype_clause = f"for type { file_type !r} "
263+
264+ if len (supported_libs ) == 1 :
265+ msg = (
266+ f"Currently only { supported_libs [0 ]} DataFrames can be saved "
267+ f"{ ftype_clause } . { df_lib } DataFrames are not yet supported."
268+ )
269+ else :
270+ msg = (
271+ f"Currently only DataFrames from the following libraries can be saved "
272+ f"{ ftype_clause } : { supported_libs !r} ."
273+ )
274+
275+ raise NotImplementedError (msg )
276+
277+ return df_lib
0 commit comments