22
33import sys
44from copy import copy
5- from typing import Any
5+ from typing import TYPE_CHECKING , Any
66
77import cloudpickle
88from sortedcontainers import SortedDict , SortedSet
1515 partial_function_from_dataframe ,
1616)
1717
18+ if TYPE_CHECKING :
19+ from collections .abc import Sequence
20+ from typing import Callable
21+
1822try :
1923 import pandas
2024
@@ -82,12 +86,17 @@ class SequenceLearner(BaseLearner):
8286 the added benefit of having results in the local kernel already.
8387 """
8488
85- def __init__ (self , function , sequence ):
89+ def __init__ (
90+ self ,
91+ function : Callable [[Any ], Any ],
92+ sequence : Sequence [Any ],
93+ ):
8694 self ._original_function = function
8795 self .function = _IgnoreFirstArgument (function )
8896 # prefer range(len(...)) over enumerate to avoid slowdowns
8997 # when passing lazy sequences
90- self ._to_do_indices = SortedSet (range (len (sequence )))
98+ indices = range (len (sequence ))
99+ self ._to_do_indices = SortedSet (indices )
91100 self ._ntotal = len (sequence )
92101 self .sequence = copy (sequence )
93102 self .data = SortedDict ()
@@ -161,6 +170,8 @@ def to_dataframe( # type: ignore[override]
161170 index_name : str = "i" ,
162171 x_name : str = "x" ,
163172 y_name : str = "y" ,
173+ * ,
174+ full_sequence : bool = False ,
164175 ) -> pandas .DataFrame :
165176 """Return the data as a `pandas.DataFrame`.
166177
@@ -178,6 +189,9 @@ def to_dataframe( # type: ignore[override]
178189 Name of the input value, by default "x"
179190 y_name : str, optional
180191 Name of the output value, by default "y"
192+ full_sequence : bool, optional
193+ If True, the returned dataframe will have the full sequence
194+ where the y_name values are pd.NA if not evaluated yet.
181195
182196 Returns
183197 -------
@@ -190,8 +204,16 @@ def to_dataframe( # type: ignore[override]
190204 """
191205 if not with_pandas :
192206 raise ImportError ("pandas is not installed." )
193- indices , ys = zip (* self .data .items ()) if self .data else ([], [])
194- sequence = [self .sequence [i ] for i in indices ]
207+ import pandas as pd
208+
209+ if full_sequence :
210+ indices = list (range (len (self .sequence )))
211+ sequence = list (self .sequence )
212+ ys = [self .data .get (i , pd .NA ) for i in indices ]
213+ else :
214+ indices , ys = zip (* self .data .items ()) if self .data else ([], []) # type: ignore[assignment]
215+ sequence = [self .sequence [i ] for i in indices ]
216+
195217 df = pandas .DataFrame (indices , columns = [index_name ])
196218 df [x_name ] = sequence
197219 df [y_name ] = ys
@@ -209,6 +231,8 @@ def load_dataframe( # type: ignore[override]
209231 index_name : str = "i" ,
210232 x_name : str = "x" ,
211233 y_name : str = "y" ,
234+ * ,
235+ full_sequence : bool = False ,
212236 ):
213237 """Load data from a `pandas.DataFrame`.
214238
@@ -231,10 +255,25 @@ def load_dataframe( # type: ignore[override]
231255 The ``x_name`` used in ``to_dataframe``, by default "x"
232256 y_name : str, optional
233257 The ``y_name`` used in ``to_dataframe``, by default "y"
258+ full_sequence : bool, optional
259+ The ``full_sequence`` used in ``to_dataframe``, by default False
234260 """
261+ if not with_pandas :
262+ raise ImportError ("pandas is not installed." )
263+ import pandas as pd
264+
235265 indices = df [index_name ].values
236266 xs = df [x_name ].values
237- self .tell_many (zip (indices , xs ), df [y_name ].values )
267+ ys = df [y_name ].values
268+
269+ if full_sequence :
270+ evaluated_indices = [i for i , y in enumerate (ys ) if y is not pd .NA ]
271+ xs = xs [evaluated_indices ]
272+ ys = ys [evaluated_indices ]
273+ indices = indices [evaluated_indices ]
274+
275+ self .tell_many (zip (indices , xs ), ys )
276+
238277 if with_default_function_args :
239278 self .function = partial_function_from_dataframe (
240279 self ._original_function , df , function_prefix
0 commit comments