@@ -87,7 +87,11 @@ def to_dataframe(
8787 return df
8888
8989 def load_dataframe (
90- self , df : pandas .DataFrame , extra_data_name : str = "extra_data" , ** kwargs
90+ self ,
91+ df : pandas .DataFrame ,
92+ extra_data_name : str = "extra_data" ,
93+ input_names : tuple [str ] = (),
94+ ** kwargs
9195 ):
9296 """Load the data from a `pandas.DataFrame` into the learner.
9397
@@ -97,11 +101,18 @@ def load_dataframe(
97101 DataFrame with the data to load.
98102 extra_data_name : str, optional
99103 The ``extra_data_name`` used in `to_dataframe`, by default "extra_data".
104+ input_names : tuple[str], optional
105+ The input names of the child learner. By default the input names are
106+ taken from ``df.attrs["inputs"]``, however, metadata is not preserved
107+ when saving/loading a DataFrame to/from a file. In that case, the input
108+ names can be passed explicitly. For example, for a 2D learner, this would
109+ be ``input_names=('x', 'y')``.
100110 **kwargs : dict
101111 Keyword arguments passed to each ``child_learner.load_dataframe(**kwargs)``.
102112 """
103113 self .learner .load_dataframe (df , ** kwargs )
104- for _ , x in df [df .attrs ["inputs" ] + [extra_data_name ]].iterrows ():
114+ keys = df .attrs .get ("inputs" , list (input_names ))
115+ for _ , x in df [keys + [extra_data_name ]].iterrows ():
105116 key = _to_key (x [:- 1 ])
106117 self .extra_data [key ] = x [- 1 ]
107118
0 commit comments