diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py index 81759b19..9fe33323 100644 --- a/mesa_frames/concrete/agentset.py +++ b/mesa_frames/concrete/agentset.py @@ -76,7 +76,7 @@ def step(self): class AgentSetPolars(AgentSetDF, PolarsMixin): """Polars-based implementation of AgentSetDF.""" - _df: pl.DataFrame + _df: pl.LazyFrame _copy_with_method: dict[str, tuple[str, list[str]]] = { "_df": ("clone", []), } @@ -93,19 +93,19 @@ def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: """ self._model = model # No definition of schema with unique_id, as it becomes hard to add new agents - self._df = pl.DataFrame() - self._mask = pl.repeat(True, len(self._df), dtype=pl.Boolean, eager=True) + self._df = pl.LazyFrame(schema={"unique_id": pl.Int64}) + self._mask = pl.repeat(True, 0, dtype=pl.Boolean) def add( self, - agents: pl.DataFrame | Sequence[Any] | dict[str, Any], + agents: pl.DataFrame | pl.LazyFrame | Sequence[Any] | dict[str, Any], inplace: bool = True, ) -> Self: """Add agents to the AgentSetPolars. Parameters ---------- - agents : pl.DataFrame | Sequence[Any] | dict[str, Any] + agents : pl.DataFrame | pl.LazyFrame | Sequence[Any] | dict[str, Any] The agents to add. inplace : bool, optional Whether to add the agents in place, by default True. @@ -121,26 +121,25 @@ def add( "AgentSetPolars.add() does not accept AgentSetDF objects. " "Extract the DataFrame with agents.agents.drop('unique_id') first." ) - elif isinstance(agents, pl.DataFrame): + elif isinstance(agents, (pl.DataFrame, pl.LazyFrame)): if "unique_id" in agents.columns: raise ValueError("Dataframe should not have a unique_id column.") - new_agents = agents + new_agents = agents.lazy() if isinstance(agents, pl.DataFrame) else agents elif isinstance(agents, dict): if "unique_id" in agents: raise ValueError("Dictionary should not have a unique_id key.") - new_agents = pl.DataFrame(agents) + new_agents = pl.LazyFrame(agents) else: # Sequence - if len(obj._df) != 0: + if len(obj._df.collect()) != 0: # For non-empty AgentSet, check column count expected_columns = len(obj._df.columns) - 1 # Exclude unique_id if len(agents) != expected_columns: raise ValueError( f"Length of data ({len(agents)}) must match the number of columns in the AgentSet (excluding unique_id): {expected_columns}" ) - new_agents = pl.DataFrame( + new_agents = pl.LazyFrame( [list(agents)], schema=[col for col in obj._df.schema if col != "unique_id"], - orient="row", ) else: # For empty AgentSet, cannot infer schema from sequence @@ -149,19 +148,19 @@ def add( ) new_agents = new_agents.with_columns( - self._generate_unique_ids(len(new_agents)).alias("unique_id") + self._generate_unique_ids(len(new_agents.collect())).alias("unique_id") ) # If self._mask is pl.Expr, then new mask is the same. # If self._mask is pl.Series[bool], then new mask has to be updated. - originally_empty = len(obj._df) == 0 + originally_empty = len(obj._df.collect()) == 0 if isinstance(obj._mask, pl.Series) and not originally_empty: - original_active_indices = obj._df.filter(obj._mask)["unique_id"] + original_active_indices = obj._df.filter(obj._mask).collect()["unique_id"] obj._df = pl.concat([obj._df, new_agents], how="diagonal_relaxed") if isinstance(obj._mask, pl.Series) and not originally_empty: - obj._update_mask(original_active_indices, new_agents["unique_id"]) + obj._update_mask(original_active_indices, new_agents.collect()["unique_id"]) return obj @@ -175,12 +174,14 @@ def contains( self, agents: PolarsIdsLike, ) -> bool | pl.Series: + # Need to collect for containment check + agent_ids = self._df.collect()["unique_id"] if isinstance(agents, pl.Series): - return agents.is_in(self._df["unique_id"]) + return agents.is_in(agent_ids) elif isinstance(agents, Collection) and not isinstance(agents, str): - return pl.Series(agents, dtype=pl.UInt64).is_in(self._df["unique_id"]) + return pl.Series(agents, dtype=pl.UInt64).is_in(agent_ids) else: - return agents in self._df["unique_id"] + return agents in agent_ids def get( self, @@ -190,11 +191,11 @@ def get( masked_df = self._get_masked_df(mask) if attr_names is None: # Return all columns except unique_id - return masked_df.select(pl.exclude("unique_id")) - attr_names = self.df.select(attr_names).columns.copy() + return masked_df.select(pl.exclude("unique_id")).collect() + attr_names = self._df.select(attr_names).columns.copy() if attr_names else [] if not attr_names: - return masked_df - masked_df = masked_df.select(attr_names) + return masked_df.collect() + masked_df = masked_df.select(attr_names).collect() if masked_df.shape[1] == 1: return masked_df[masked_df.columns[0]] return masked_df @@ -210,19 +211,29 @@ def set( masked_df = obj._get_masked_df(mask) if not attr_names: - attr_names = masked_df.columns + attr_names = masked_df.collect().columns attr_names.remove("unique_id") def process_single_attr( - masked_df: pl.DataFrame, attr_name: str, values: Any - ) -> pl.DataFrame: - if isinstance(values, pl.DataFrame): - values_series = values.to_series() - elif isinstance(values, (pl.Expr, pl.Series, Collection)): - values_series = pl.Series(values) + masked_df: pl.LazyFrame, attr_name: str, values: Any + ) -> pl.LazyFrame: + if isinstance(values, (pl.DataFrame, pl.LazyFrame)): + values_series = ( + values.collect() if isinstance(values, pl.LazyFrame) else values + ) + return masked_df.with_columns( + values_series.to_series().alias(attr_name) + ) + elif isinstance(values, pl.Expr): + return masked_df.with_columns(values.alias(attr_name)) + elif isinstance(values, pl.Series): + return masked_df.with_columns(values.alias(attr_name)) else: - values_series = pl.repeat(values, len(masked_df)) - return masked_df.with_columns(values_series.alias(attr_name)) + if isinstance(values, Collection): + values = pl.Series(values) + else: + values = pl.repeat(values, masked_df.collect().height) + return masked_df.with_columns(values.alias(attr_name)) if isinstance(attr_names, str) and values is not None: masked_df = process_single_attr(masked_df, attr_names, values) @@ -241,10 +252,11 @@ def process_single_attr( "attr_names must be a string, a collection of string or a dictionary with columns as keys and values." ) unique_id_column = None + unique_id_column = None if "unique_id" not in obj._df: - unique_id_column = self._generate_unique_ids(len(masked_df)).alias( - "unique_id" - ) + unique_id_column = self._generate_unique_ids( + len(masked_df.collect()) + ).alias("unique_id") obj._df = obj._df.with_columns(unique_id_column) masked_df = masked_df.with_columns(unique_id_column) b_mask = obj._get_bool_mask(mask) @@ -252,7 +264,7 @@ def process_single_attr( original_index = obj._df.select("unique_id") obj._df = pl.concat([non_masked_df, masked_df], how="diagonal_relaxed") obj._df = original_index.join(obj._df, on="unique_id", how="left") - obj._update_mask(original_index["unique_id"], unique_id_column) + obj._update_mask(original_index.collect()["unique_id"], unique_id_column) return obj def select( @@ -268,9 +280,9 @@ def select( if filter_func: mask = mask & filter_func(obj) if n is not None: - mask = (obj._df["unique_id"]).is_in( - obj._df.filter(mask).sample(n)["unique_id"] - ) + # Need to collect for sampling + sample_ids = obj._df.filter(mask).collect().sample(n)["unique_id"] + mask = (obj._df.collect()["unique_id"]).is_in(sample_ids) if negate: mask = mask.not_() obj._mask = mask @@ -278,10 +290,15 @@ def select( def shuffle(self, inplace: bool = True) -> Self: obj = self._get_obj(inplace) - obj._df = obj._df.sample( - fraction=1, - shuffle=True, - seed=obj.random.integers(np.iinfo(np.int32).max), + # Collect to perform shuffle, then convert back to LazyFrame + obj._df = ( + obj._df.collect() + .sample( + fraction=1, + shuffle=True, + seed=obj.random.integers(np.iinfo(np.int32).max), + ) + .lazy() ) return obj @@ -308,8 +325,8 @@ def _concatenate_agentsets( original_masked_index: pl.Series | None = None, ) -> Self: if not duplicates_allowed: - indices_list = [self._df["unique_id"]] + [ - agentset._df["unique_id"] for agentset in agentsets + indices_list = [self._df.collect()["unique_id"]] + [ + agentset._df.collect()["unique_id"] for agentset in agentsets ] all_indices = pl.concat(indices_list) if all_indices.is_duplicated().any(): @@ -321,23 +338,29 @@ def _concatenate_agentsets( max_length = max(len(agentset) for agentset in agentsets) for agentset in agentsets: if len(agentset) == max_length: - original_index = agentset._df["unique_id"] + original_index = agentset._df.collect()["unique_id"] final_dfs = [self._df] - final_active_indices = [self._df["unique_id"]] - final_indices = self._df["unique_id"].clone() + final_active_indices = [self._df.filter(self._mask).collect()["unique_id"]] + final_indices = self._df.collect()["unique_id"].clone() for obj in iter(agentsets): # Remove agents that are already in the final DataFrame final_dfs.append( obj._df.filter(pl.col("unique_id").is_in(final_indices).not_()) ) # Add the indices of the active agents of current AgentSet - final_active_indices.append(obj._df.filter(obj._mask)["unique_id"]) + final_active_indices.append( + obj._df.filter(obj._mask).collect()["unique_id"] + ) # Update the indices of the agents in the final DataFrame final_indices = pl.concat( - [final_indices, final_dfs[-1]["unique_id"]], how="vertical" + [ + final_indices, + final_dfs[-1].collect()["unique_id"], + ], + how="vertical", ) # Left-join original index with concatenated dfs to keep original ids order - final_df = original_index.to_frame().join( + final_df = pl.LazyFrame({"unique_id": original_index}).join( pl.concat(final_dfs, how="diagonal_relaxed"), on="unique_id", how="left" ) # @@ -346,15 +369,15 @@ def _concatenate_agentsets( else: final_df = pl.concat([obj._df for obj in agentsets], how="diagonal_relaxed") final_active_index = pl.concat( - [obj._df.filter(obj._mask)["unique_id"] for obj in agentsets] + [obj._df.filter(obj._mask).collect()["unique_id"] for obj in agentsets] ) - final_mask = final_df["unique_id"].is_in(final_active_index) + final_mask = final_df.collect()["unique_id"].is_in(final_active_index) self._df = final_df self._mask = final_mask # If some ids were removed in the do-method, we need to remove them also from final_df if not isinstance(original_masked_index, type(None)): ids_to_remove = original_masked_index.filter( - original_masked_index.is_in(self._df["unique_id"]).not_() + original_masked_index.is_in(self._df.collect()["unique_id"]).not_() ) if not ids_to_remove.is_empty(): self.remove(ids_to_remove, inplace=True) @@ -368,10 +391,10 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: if ( isinstance(mask, pl.Series) and mask.dtype == pl.Boolean - and len(mask) == len(self._df) + and len(mask) == len(self._df.collect()) ): return mask - return self._df["unique_id"].is_in(mask) + return self._df.collect()["unique_id"].is_in(mask) if isinstance(mask, pl.Expr): return mask @@ -387,7 +410,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: "DataFrame must have a 'unique_id' column or a single boolean column." ) elif mask is None or mask == "all": - return pl.repeat(True, len(self._df)) + return pl.repeat(True, len(self._df.collect())) elif mask == "active": return self._mask elif isinstance(mask, Collection): @@ -398,23 +421,28 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: def _get_masked_df( self, mask: AgentPolarsMask = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if (isinstance(mask, pl.Series) and mask.dtype == pl.Boolean) or isinstance( mask, pl.Expr ): return self._df.filter(mask) - elif isinstance(mask, pl.DataFrame): - if not mask["unique_id"].is_in(self._df["unique_id"]).all(): + elif isinstance(mask, (pl.DataFrame, pl.LazyFrame)): + mask_df = mask.collect() if isinstance(mask, pl.LazyFrame) else mask + agents_ids = self._df.collect()["unique_id"] + if not mask_df["unique_id"].is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) - return mask.select("unique_id").join(self._df, on="unique_id", how="left") + return pl.LazyFrame({"unique_id": mask_df["unique_id"]}).join( + self._df, on="unique_id", how="left" + ) elif isinstance(mask, pl.Series): - if not mask.is_in(self._df["unique_id"]).all(): + agents_ids = self._df.collect()["unique_id"] + if not mask.is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) - mask_df = mask.to_frame("unique_id") + mask_df = pl.LazyFrame({"unique_id": mask}) return mask_df.join(self._df, on="unique_id", how="left") elif mask is None or mask == "all": return self._df @@ -425,11 +453,12 @@ def _get_masked_df( mask_series = pl.Series(mask, dtype=pl.UInt64) else: mask_series = pl.Series([mask], dtype=pl.UInt64) - if not mask_series.is_in(self._df["unique_id"]).all(): + agents_ids = self._df.collect()["unique_id"] + if not mask_series.is_in(agents_ids).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) - mask_df = mask_series.to_frame("unique_id") + mask_df = pl.LazyFrame({"unique_id": mask_series}) return mask_df.join(self._df, on="unique_id", how="left") @overload @@ -438,14 +467,19 @@ def _get_obj_copy(self, obj: pl.Series) -> pl.Series: ... @overload def _get_obj_copy(self, obj: pl.DataFrame) -> pl.DataFrame: ... - def _get_obj_copy(self, obj: pl.Series | pl.DataFrame) -> pl.Series | pl.DataFrame: + @overload + def _get_obj_copy(self, obj: pl.LazyFrame) -> pl.LazyFrame: ... + + def _get_obj_copy( + self, obj: pl.Series | pl.DataFrame | pl.LazyFrame + ) -> pl.Series | pl.DataFrame | pl.LazyFrame: return obj.clone() def _discard(self, ids: PolarsIdsLike) -> Self: mask = self._get_bool_mask(ids) if isinstance(self._mask, pl.Series): - original_active_indices = self._df.filter(self._mask)["unique_id"] + original_active_indices = self._df.filter(self._mask).collect()["unique_id"] self._df = self._df.filter(mask.not_()) @@ -457,16 +491,17 @@ def _discard(self, ids: PolarsIdsLike) -> Self: def _update_mask( self, original_active_indices: pl.Series, new_indices: pl.Series | None = None ) -> None: + agent_ids = self._df.collect()["unique_id"] if new_indices is not None: - self._mask = self._df["unique_id"].is_in( - original_active_indices - ) | self._df["unique_id"].is_in(new_indices) + self._mask = agent_ids.is_in(original_active_indices) | agent_ids.is_in( + new_indices + ) else: - self._mask = self._df["unique_id"].is_in(original_active_indices) + self._mask = agent_ids.is_in(original_active_indices) def __getattr__(self, key: str) -> pl.Series: super().__getattr__(key) - return self._df[key] + return self._df.collect()[key] def _generate_unique_ids(self, n: int) -> pl.Series: return pl.Series( @@ -510,26 +545,28 @@ def __getitem__( return attr def __iter__(self) -> Iterator[dict[str, Any]]: - return iter(self._df.iter_rows(named=True)) + return iter(self._df.collect().iter_rows(named=True)) def __len__(self) -> int: - return len(self._df) + return len(self._df.collect()) def __reversed__(self) -> Iterator: - return reversed(iter(self._df.iter_rows(named=True))) + return reversed(iter(self._df.collect().iter_rows(named=True))) @property - def df(self) -> pl.DataFrame: + def df(self) -> pl.LazyFrame: return self._df @df.setter - def df(self, agents: pl.DataFrame) -> None: - if "unique_id" not in agents.columns: + def df(self, agents: pl.DataFrame | pl.LazyFrame) -> None: + if "unique_id" not in ( + agents.columns if isinstance(agents, pl.LazyFrame) else agents.columns + ): raise KeyError("DataFrame must have a unique_id column.") - self._df = agents + self._df = agents.lazy() if isinstance(agents, pl.DataFrame) else agents @property - def active_agents(self) -> pl.DataFrame: + def active_agents(self) -> pl.LazyFrame: return self.df.filter(self._mask) @active_agents.setter @@ -537,13 +574,13 @@ def active_agents(self, mask: AgentPolarsMask) -> None: self.select(mask=mask, inplace=True) @property - def inactive_agents(self) -> pl.DataFrame: + def inactive_agents(self) -> pl.LazyFrame: return self.df.filter(~self._mask) @property def index(self) -> pl.Series: - return self._df["unique_id"] + return self._df.collect()["unique_id"] @property - def pos(self) -> pl.DataFrame: + def pos(self) -> pl.LazyFrame: return super().pos diff --git a/mesa_frames/concrete/mixin.py b/mesa_frames/concrete/mixin.py index eba00ae6..974d293e 100644 --- a/mesa_frames/concrete/mixin.py +++ b/mesa_frames/concrete/mixin.py @@ -9,11 +9,11 @@ Classes: PolarsMixin(DataFrameMixin): A Polars-based implementation of DataFrame operations. This class provides - methods for manipulating and analyzing data stored in Polars DataFrames, + methods for manipulating and analyzing data stored in Polars LazyFrames, tailored for use in mesa-frames components like AgentSetPolars and GridPolars. The PolarsMixin class is designed to be used as a mixin with other mesa-frames -classes, providing them with Polars-specific DataFrame functionality. It implements +classes, providing them with Polars-specific LazyFrame functionality. It implements the abstract methods defined in the DataFrameMixin, ensuring consistent DataFrame operations across the mesa-frames package. @@ -26,7 +26,7 @@ class AgentSetPolars(AgentSetDF, PolarsMixin): def __init__(self, model): super().__init__(model) - self.agents = pl.DataFrame() # Initialize empty DataFrame + self.agents = pl.LazyFrame() # Initialize empty LazyFrame def some_method(self): # Use Polars operations provided by the mixin @@ -34,8 +34,8 @@ def some_method(self): # ... further processing ... Features: - - High-performance DataFrame operations using Polars - - Support for both eager and lazy evaluation + - High-performance LazyFrame operations using Polars + - Support for lazy evaluation with improved query optimization - Efficient memory usage and fast computation - Integration with Polars' query optimization capabilities @@ -55,7 +55,7 @@ def some_method(self): class PolarsMixin(DataFrameMixin): - """Polars-specific implementation of DataFrame operations.""" + """Polars-specific implementation of DataFrame operations using LazyFrames.""" # TODO: complete with other dtypes _dtypes_mapping: dict[str, Any] = { @@ -70,7 +70,7 @@ def _df_add( other: pl.DataFrame | Collection[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -81,13 +81,15 @@ def _df_add( def _df_all( self, - df: pl.DataFrame, + df: pl.LazyFrame, name: str = "all", axis: Literal["index", "columns"] = "columns", - ) -> pl.Series: + ) -> pl.Expr: if axis == "index": - return pl.Series(name, df.select(pl.col("*").all()).row(0)) - return df.with_columns(pl.all_horizontal("*").alias(name))[name] + # Return an expression that will evaluate to all values across index + return pl.all(pl.col("*")).alias(name) + # Return an expression for all values across columns + return pl.all_horizontal(pl.col("*")).alias(name) def _df_and( self, @@ -95,7 +97,7 @@ def _df_and( other: pl.DataFrame | Collection[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -104,73 +106,85 @@ def _df_and( index_cols=index_cols, ) - def _df_column_names(self, df: pl.DataFrame) -> list[str]: + def _df_column_names(self, df: pl.LazyFrame) -> list[str]: + # This operation requires schema inspection which is available on LazyFrame return df.columns def _df_combine_first( self, - original_df: pl.DataFrame, - new_df: pl.DataFrame, + original_df: pl.LazyFrame, + new_df: pl.LazyFrame, index_cols: str | list[str], - ) -> pl.DataFrame: - original_df = original_df.with_columns(_index=pl.int_range(0, len(original_df))) + ) -> pl.LazyFrame: + # Create a sequential index using with_row_count instead of int_range + original_df = original_df.with_row_count("_index") common_cols = set(original_df.columns) & set(new_df.columns) merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") - merged_df = ( - merged_df.with_columns( - pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) - for col in common_cols - ) - .select(pl.exclude("^.*_right$")) - .sort("_index") - .drop("_index") - ) - return merged_df + + # Use expressions to coalesce values + coalesce_exprs = [ + pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) + for col in common_cols + if col in merged_df.columns and f"{col}_right" in merged_df.columns + ] + + # Apply coalesce expressions and drop right columns + merged_df = merged_df.with_columns(coalesce_exprs) + right_cols = [col for col in merged_df.columns if col.endswith("_right")] + merged_df = merged_df.drop(right_cols) + + # Sort by index and drop index column + return merged_df.sort("_index").drop("_index") @overload def _df_concat( self, - objs: Collection[pl.DataFrame], + objs: Collection[pl.LazyFrame], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... @overload def _df_concat( self, - objs: Collection[pl.Series], + objs: Collection[pl.Expr], how: Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.Series: ... + ) -> pl.Expr: ... @overload def _df_concat( self, - objs: Collection[pl.Series], + objs: Collection[pl.Expr], how: Literal["horizontal"] = "horizontal", ignore_index: bool = False, index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... def _df_concat( self, - objs: Collection[pl.DataFrame] | Collection[pl.Series], + objs: Collection[pl.LazyFrame] | Collection[pl.Expr], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | None = None, - ) -> pl.Series | pl.DataFrame: - if isinstance(objs[0], pl.DataFrame) and how == "vertical": + ) -> pl.LazyFrame | pl.Expr: + if isinstance(next(iter(objs), None), pl.LazyFrame) and how == "vertical": how = "diagonal_relaxed" - if isinstance(objs[0], pl.Series) and how == "horizontal": - obj = pl.DataFrame().hstack(objs, in_place=True) + + if isinstance(next(iter(objs), None), pl.Expr) and how == "horizontal": + # Convert expressions to LazyFrames for horizontal concat + obj = pl.LazyFrame().with_columns(list(objs)) else: + # Use concat on LazyFrames directly obj = pl.concat(objs, how=how) - if isinstance(obj, pl.DataFrame) and how == "horizontal" and ignore_index: - obj = obj.rename( - {c: str(i) for c, i in zip(obj.columns, range(len(obj.columns)))} - ) + + if isinstance(obj, pl.LazyFrame) and how == "horizontal" and ignore_index: + # Rename columns if ignore_index is True + rename_dict = {c: str(i) for i, c in enumerate(obj.columns)} + obj = obj.rename(rename_dict) + return obj def _df_constructor( @@ -184,25 +198,29 @@ def _df_constructor( if dtypes is not None: dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} - df = pl.DataFrame( - data=data, schema=columns, schema_overrides=dtypes, orient="row" - ) + # Create LazyFrame directly + df = pl.LazyFrame(data=data, schema=columns, schema_overrides=dtypes) + if index is not None: if index_cols is not None: if isinstance(index_cols, str): index_cols = [index_cols] - index_df = pl.DataFrame(index, index_cols) + index_df = pl.LazyFrame({col: index for col in index_cols}) else: - index_df = pl.DataFrame(index) - if len(df) != len(index_df) and len(df) == 1: - df = index_df.join(df, how="cross") + index_df = pl.LazyFrame({"index": index}) + + if len(df.schema) == 0: + # Empty LazyFrame case + df = index_df else: - df = index_df.hstack(df) + # Use cross join for single row df or regular join otherwise + df = index_df.join(df, how="cross") + return df def _df_contains( self, - df: pl.DataFrame, + df: pl.LazyFrame, column: str, values: Collection[Any], ) -> pl.Series: @@ -214,7 +232,7 @@ def _df_div( other: pl.DataFrame | Collection[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -225,40 +243,28 @@ def _df_div( def _df_drop_columns( self, - df: pl.DataFrame, + df: pl.LazyFrame, columns: str | list[str], - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return df.drop(columns) def _df_drop_duplicates( self, - df: pl.DataFrame, + df: pl.LazyFrame, subset: str | list[str] | None = None, keep: Literal["first", "last", False] = "first", - ) -> pl.DataFrame: + ) -> pl.LazyFrame: # If subset is None, use all columns if subset is None: subset = df.columns - original_col_order = df.columns + if keep == "first": - return ( - df.group_by(subset, maintain_order=True) - .first() - .select(original_col_order) - ) + return df.unique(subset=subset, keep="first") elif keep == "last": - return ( - df.group_by(subset, maintain_order=True) - .last() - .select(original_col_order) - ) + return df.unique(subset=subset, keep="last") else: - return ( - df.with_columns(pl.len().over(subset)) - .filter(pl.col("len") < 2) - .drop("len") - .select(original_col_order) - ) + # For keep=False, drop all duplicates + return df.filter(~pl.col(subset).is_duplicated()) def _df_ge( self, @@ -266,7 +272,7 @@ def _df_ge( other: pl.DataFrame | Collection[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -277,48 +283,49 @@ def _df_ge( def _df_get_bool_mask( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, mask: PolarsMask = None, negate: bool = False, - ) -> pl.Series | pl.Expr: - def bool_mask_from_series(mask: pl.Series) -> pl.Series: - if ( - isinstance(mask, pl.Series) - and mask.dtype == pl.Boolean - and len(mask) == len(df) - ): - return mask - assert isinstance(index_cols, str) - return df[index_cols].is_in(mask) - - def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: - assert index_cols, list[str] - mask = mask[index_cols].unique() - mask = mask.with_columns(in_it=True) - return df.join(mask, on=index_cols, how="left")["in_it"].fill_null(False) + ) -> pl.Expr: + def bool_mask_from_expr(mask: pl.Expr) -> pl.Expr: + return mask + + def bool_mask_from_lazyframe(mask: pl.LazyFrame) -> pl.Expr: + if index_cols is None: + raise ValueError( + "index_cols must be provided when using LazyFrame mask" + ) - if isinstance(mask, pl.Expr): - result = mask - elif isinstance(mask, pl.Series): - result = bool_mask_from_series(mask) - elif isinstance(mask, pl.DataFrame): - if index_cols in mask.columns: - result = bool_mask_from_series(mask[index_cols]) - elif all(col in mask.columns for col in index_cols): - result = bool_mask_from_df(mask[index_cols]) - elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: - result = bool_mask_from_series(mask[mask.columns[0]]) + if isinstance(index_cols, str): + return pl.col(index_cols).is_in(mask.select(index_cols)) + else: + # For multiple index columns, create an expression to check if in the mask + join_cols = [pl.col(col) for col in index_cols] + return pl.struct(join_cols).is_in(mask.select(index_cols)) + + def bool_mask_from_values(values) -> pl.Expr: + if index_cols is None: + raise ValueError("index_cols must be provided when using value mask") + + if isinstance(index_cols, str): + return pl.col(index_cols).is_in(values) else: - raise KeyError( - f"Mask must have {index_cols} column(s) or a single boolean column." + # This is simplified and may need adjustment for multi-column case + raise NotImplementedError( + "Multi-column masking with raw values not implemented" ) + + if isinstance(mask, pl.Expr): + result = bool_mask_from_expr(mask) + elif isinstance(mask, pl.LazyFrame): + result = bool_mask_from_lazyframe(mask) elif mask is None or mask == "all": - result = pl.Series([True] * len(df)) + result = pl.lit(True) elif isinstance(mask, Collection): - result = bool_mask_from_series(pl.Series(mask)) + result = bool_mask_from_values(mask) else: - result = bool_mask_from_series(pl.Series([mask])) + result = bool_mask_from_values([mask]) if negate: result = ~result @@ -327,21 +334,21 @@ def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: def _df_get_masked_df( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, mask: PolarsMask | None = None, columns: list[str] | None = None, negate: bool = False, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: b_mask = self._df_get_bool_mask(df, index_cols, mask, negate=negate) if columns: - return df.filter(b_mask)[columns] + return df.filter(b_mask).select(columns) return df.filter(b_mask) def _df_groupby_cumcount( - self, df: pl.DataFrame, by: str | list[str], name="cum_count" - ) -> pl.Series: - return df.with_columns(pl.cum_count(by).over(by).alias(name))[name] + self, df: pl.LazyFrame, by: str | list[str], name="cum_count" + ) -> pl.Expr: + return pl.cumcount().over(by).alias(name) @overload def _df_index(self, df: pl.DataFrame, index_col: str) -> pl.Series: ... @@ -356,13 +363,13 @@ def _df_index( ) -> pl.Series | pl.DataFrame: return df[index_col] - def _df_iterator(self, df: pl.DataFrame) -> Iterator[dict[str, Any]]: - return iter(df.iter_rows(named=True)) + def _df_iterator(self, df: pl.LazyFrame) -> Iterator[dict[str, Any]]: + return iter(df.collect().iter_rows(named=True)) def _df_join( self, - left: pl.DataFrame, - right: pl.DataFrame, + left: pl.LazyFrame, + right: pl.LazyFrame, index_cols: str | list[str] | None = None, on: str | list[str] | None = None, left_on: str | list[str] | None = None, @@ -375,7 +382,7 @@ def _df_join( | Literal["cross"] ) = "left", suffix="_right", - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if how == "outer": how = "full" if how == "right": @@ -392,7 +399,7 @@ def _df_lt( other: pl.DataFrame | Collection[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -407,7 +414,7 @@ def _df_mod( other: pl.DataFrame | Collection[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -422,7 +429,7 @@ def _df_mul( other: pl.DataFrame | Collection[float | int], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -434,27 +441,27 @@ def _df_mul( @overload def _df_norm( self, - df: pl.DataFrame, + df: pl.LazyFrame, srs_name: str = "norm", include_cols: Literal[False] = False, - ) -> pl.Series: ... + ) -> pl.Expr: ... @overload def _df_norm( self, - df: pl.Series, + df: pl.Expr, srs_name: str = "norm", include_cols: Literal[True] = True, - ) -> pl.DataFrame: ... + ) -> pl.LazyFrame: ... def _df_norm( self, - df: pl.DataFrame, + df: pl.LazyFrame, srs_name: str = "norm", include_cols: bool = False, - ) -> pl.Series | pl.DataFrame: + ) -> pl.Expr | pl.LazyFrame: srs = ( - df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().rename(srs_name) + df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().alias(srs_name) ) if include_cols: return df.with_columns(srs) @@ -467,44 +474,52 @@ def _df_operation( operation: Callable[[pl.Expr, pl.Expr], pl.Expr], axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: - if isinstance(other, pl.DataFrame): + ) -> pl.LazyFrame: + if isinstance(other, pl.LazyFrame): if index_cols is not None: + # Join with the other LazyFrame op_df = df.join(other, how="left", on=index_cols, suffix="_op") else: - assert len(df) == len(other), ( - "DataFrames must have the same length if index_cols is not specified" - ) - index_cols = [] - other = other.rename(lambda col: col + "_op") + # Without index cols, assume matching order and do a horizontal concat + other = other.rename({col: f"{col}_op" for col in other.columns}) op_df = pl.concat([df, other], how="horizontal") - return op_df.with_columns( - operation(pl.col(col), pl.col(f"{col}_op")).alias(col) - for col in df.columns - if col not in index_cols - ).select(df.columns) - elif isinstance( - other, (Sequence, pl.Series) - ): # Currently, pl.Series is not a Sequence + + # Apply the operation to matching columns + expr_list = [] + for col in df.columns: + if col not in (index_cols or []): + if f"{col}_op" in op_df.columns: + expr_list.append( + operation(pl.col(col), pl.col(f"{col}_op")).alias(col) + ) + else: + expr_list.append(pl.col(col)) + else: + expr_list.append(pl.col(col)) + + return op_df.with_columns(expr_list).select(df.columns) + elif isinstance(other, (Sequence, pl.Series)): if axis == "index": - assert len(df) == len(other), ( - "Sequence must have the same length as df if axis is 'index'" - ) - other_series = pl.Series("operand", other) - return df.with_columns( - operation(pl.col(col), other_series).alias(col) - for col in df.columns - ) + # Apply operation row-wise + if isinstance(other, pl.Series): + # Convert Series to an expression + other_expr = pl.lit(other.to_list()) + else: + other_expr = pl.lit(list(other)) + + expr_list = [ + operation(pl.col(col), other_expr).alias(col) for col in df.columns + ] + return df.with_columns(expr_list) else: - assert len(df.columns) == len(other), ( - "Sequence must have the same length as df.columns if axis is 'columns'" - ) - return df.with_columns( + # Apply operation column-wise + expr_list = [ operation(pl.col(col), pl.lit(other[i])).alias(col) for i, col in enumerate(df.columns) - ) + ] + return df.with_columns(expr_list) else: - raise ValueError("other must be a DataFrame or a Sequence") + raise ValueError("other must be a LazyFrame or a Sequence") def _df_or( self, @@ -512,7 +527,7 @@ def _df_or( other: pl.DataFrame | Collection[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return self._df_operation( df=df, other=other, @@ -527,13 +542,13 @@ def _df_reindex( other: Sequence[Hashable] | pl.DataFrame | pl.Series, new_index_cols: str | list[str], original_index_cols: str | list[str] | None = None, - ) -> pl.DataFrame: - # If other is a DataFrame, extract the index columns - if isinstance(other, pl.DataFrame): + ) -> pl.LazyFrame: + # If other is a LazyFrame, extract the index columns + if isinstance(other, pl.LazyFrame): other = other.select(new_index_cols) else: - # If other is a sequence, create a DataFrame with it - other = pl.Series(name=new_index_cols, values=other).to_frame() + # If other is a sequence, create a LazyFrame with it + other = pl.LazyFrame({new_index_cols: other}) # Perform a left join to reindex if original_index_cols is None: @@ -544,16 +559,16 @@ def _df_reindex( return result def _df_rename_columns( - self, df: pl.DataFrame, old_columns: list[str], new_columns: list[str] - ) -> pl.DataFrame: + self, df: pl.LazyFrame, old_columns: list[str], new_columns: list[str] + ) -> pl.LazyFrame: return df.rename(dict(zip(old_columns, new_columns))) def _df_reset_index( self, - df: pl.DataFrame, + df: pl.LazyFrame, index_cols: str | list[str] | None = None, drop: bool = False, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if drop and index_cols is not None: return df.drop(index_cols) else: @@ -561,13 +576,13 @@ def _df_reset_index( def _df_sample( self, - df: pl.DataFrame, + df: pl.LazyFrame, n: int | None = None, frac: float | None = None, with_replacement: bool = False, shuffle: bool = False, seed: int | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: return df.sample( n=n, fraction=frac, @@ -588,22 +603,22 @@ def _df_set_index( def _df_with_columns( self, - original_df: pl.DataFrame, - data: Sequence | pl.DataFrame | Sequence[Sequence] | dict[str | Any] | Any, + original_df: pl.LazyFrame, + data: Sequence | pl.LazyFrame | Sequence[Sequence] | dict[str | Any] | Any, new_columns: str | list[str] | None = None, - ) -> pl.DataFrame: + ) -> pl.LazyFrame: if ( (isinstance(data, Sequence) and isinstance(data[0], Sequence)) or isinstance( - data, pl.DataFrame - ) # Currently, pl.DataFrame is not a Sequence + data, pl.LazyFrame + ) # Currently, pl.LazyFrame is not a Sequence or ( isinstance(data, dict) and isinstance(data[list(data.keys())[0]], Sequence) ) ): # This means that data is a Sequence of Sequences (rows) - data = pl.DataFrame(data, new_columns, orient="row") + data = pl.LazyFrame(data, new_columns) original_df = original_df.select(pl.exclude(data.columns)) return original_df.hstack(data) if not isinstance(data, dict): @@ -629,7 +644,7 @@ def _srs_contains( self, srs: Collection[Any], values: Any | Sequence[Any], - ) -> pl.Series: + ) -> pl.Expr: if not isinstance(values, Collection): values = [values] return pl.Series(values).is_in(pl.Series(srs)) @@ -641,12 +656,12 @@ def _srs_range( end: int, step: int = 1, ) -> pl.Series: - return pl.arange(start=start, end=end, step=step, eager=True).rename(name) + return pl.arange(start=start, end=end, step=step, eager=True).alias(name) def _srs_to_df( self, srs: pl.Series, index: pl.Series | None = None - ) -> pl.DataFrame: - df = srs.to_frame() + ) -> pl.LazyFrame: + df = srs.to_frame().lazy() if index: return df.with_columns({index.name: index}) return df diff --git a/tests/test_agentset.py b/tests/test_agentset.py index 0c849abe..dd37a20d 100644 --- a/tests/test_agentset.py +++ b/tests/test_agentset.py @@ -92,8 +92,13 @@ def test__init__(self): agents = ExampleAgentSetPolars(model) agents.add({"age": [0, 1, 2, 3]}) assert agents.model == model - assert isinstance(agents.df, pl.DataFrame) - assert agents.df["age"].to_list() == [0, 1, 2, 3] + assert isinstance(agents.df, pl.LazyFrame) + assert agents.df.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + ] assert isinstance(agents._mask, pl.Series) assert isinstance(agents.random, Generator) assert agents.starting_wealth.to_list() == [1, 2, 3, 4] @@ -108,18 +113,58 @@ def test_add( result = agents.add( pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}), inplace=False ) - assert result.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert result.df["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert result.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with a list (Sequence[Any]) result = agents.add([5, 10], inplace=False) - assert result.df["wealth"].to_list() == [1, 2, 3, 4, 5] - assert result.df["age"].to_list() == [10, 20, 30, 40, 10] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert result.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 10, + ] # Test with a dict[str, Any] agents.add({"wealth": [5, 6], "age": [50, 60]}) - assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents.df["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test ValueError for dictionary with unique_id key (Line 131) with pytest.raises( @@ -180,16 +225,23 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): # Test with a single value result = agents.discard(agents["unique_id"][0], inplace=False) - assert all(result["unique_id"] == agents["unique_id"][1, 2, 3]) - assert all(result.pos["unique_id"] == agents["unique_id"][1, 2, 3]) + assert result.df.select("unique_id").collect()["unique_id"].to_list() == [ + 1, + 2, + 3, + ] + assert result.pos["unique_id"].to_list() == [1, 2, 3] assert result.pos["dim_0"].to_list() == [1, None, None] assert result.pos["dim_1"].to_list() == [1, None, None] result += pl.DataFrame({"wealth": 1, "age": 10}) # Test with a list result = agents.discard(agents["unique_id"][0, 1], inplace=False) - assert all(result["unique_id"] == agents["unique_id"][2, 3]) - assert all(result.pos["unique_id"] == agents["unique_id"][2, 3]) + assert result.df.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] + assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] @@ -197,37 +249,63 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): result = agents.discard( pl.DataFrame({"unique_id": agents["unique_id"][0, 1]}), inplace=False ) - assert all(result["unique_id"] == agents["unique_id"][2, 3]) - assert all(result.pos["unique_id"] == agents["unique_id"][2, 3]) + assert result.df.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] + assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] # Test with active_agents agents.active_agents = agents["unique_id"][0, 1] result = agents.discard("active", inplace=False) - assert all(result["unique_id"] == agents["unique_id"][2, 3]) - assert all(result.pos["unique_id"] == agents["unique_id"][2, 3]) + assert result.df.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] + assert result.pos["unique_id"].to_list() == [2, 3] assert result.pos["dim_0"].to_list() == [None, None] assert result.pos["dim_1"].to_list() == [None, None] # Test with empty list result = agents.discard([], inplace=False) - assert all(result.df["unique_id"] == agents["unique_id"]) + assert result.df.select("unique_id").collect()["unique_id"].to_list() == [ + 0, + 1, + 2, + 3, + ] def test_do(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with no return_results, no mask agents.do("add_wealth", 1) - assert agents.df["wealth"].to_list() == [2, 3, 4, 5] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 2, + 3, + 4, + 5, + ] # Test with return_results=True, no mask assert agents.do("add_wealth", 1, return_results=True) is None - assert agents.df["wealth"].to_list() == [3, 4, 5, 6] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 3, + 4, + 5, + 6, + ] # Test with a mask agents.do("add_wealth", 1, mask=agents["wealth"] > 3) - assert agents.df["wealth"].to_list() == [3, 5, 6, 7] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 3, + 5, + 6, + 7, + ] def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -239,17 +317,25 @@ def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): result = agents.get(["wealth", "age"]) assert isinstance(result, pl.DataFrame) assert result.columns == ["wealth", "age"] - assert result["wealth"].to_list() == agents.df["wealth"].to_list() + assert ( + result["wealth"].to_list() + == agents.df.select("wealth").collect()["wealth"].to_list() + ) # Test with a single attribute and a mask - selected = agents.select(agents.df["wealth"] > 1, inplace=False) + selected = agents.select( + agents.df.select("wealth").collect()["wealth"] > 1, inplace=False + ) assert selected.get("wealth", mask="active").to_list() == [2, 3, 4] def test_remove(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars remaining_agents_id = agents["unique_id"][2, 3] agents.remove(agents["unique_id"][0, 1]) - assert all(agents.df["unique_id"] == remaining_agents_id) + assert agents.df.select("unique_id").collect()["unique_id"].to_list() == [ + 2, + 3, + ] with pytest.raises(KeyError): agents.remove([0]) @@ -259,41 +345,53 @@ def test_select(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with default arguments. Should select all agents selected = agents.select(inplace=False) assert ( - selected.active_agents["wealth"].to_list() == agents.df["wealth"].to_list() + selected.active_agents.select("wealth").collect()["wealth"].to_list() + == agents.df.select("wealth").collect()["wealth"].to_list() ) # Test with a pl.Series[bool] mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) selected = agents.select(mask, inplace=False) - assert all(selected.active_agents["unique_id"] == agents["unique_id"][0, 2, 3]) + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 2, 3] # Test with a ListLike mask = agents["unique_id"][0, 2] selected = agents.select(mask, inplace=False) - assert all(selected.active_agents["unique_id"] == agents["unique_id"][0, 2]) + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 2] # Test with a pl.DataFrame mask = pl.DataFrame({"unique_id": agents["unique_id"][0, 1]}) selected = agents.select(mask, inplace=False) - assert all(selected.active_agents["unique_id"] == agents["unique_id"][0, 1]) + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 1] # Test with filter_func def filter_func(agentset: AgentSetPolars) -> pl.Series: - return agentset.df["wealth"] > 1 + return agentset.df.select("wealth").collect()["wealth"] > 1 selected = agents.select(filter_func=filter_func, inplace=False) - assert all(selected.active_agents["unique_id"] == agents["unique_id"][1, 2, 3]) + assert selected.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [1, 2, 3] # Test with n selected = agents.select(n=3, inplace=False) - assert len(selected.active_agents) == 3 + assert len(selected.active_agents.collect()) == 3 # Test with n, filter_func and mask mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) selected = agents.select(mask, filter_func=filter_func, n=1, inplace=False) assert any( - id in selected.active_agents["unique_id"].to_list() - for id in agents["unique_id"][2, 3] + el + in selected.active_agents.select("unique_id") + .collect()["unique_id"] + .to_list() + for el in [2, 3] ) def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): @@ -301,40 +399,82 @@ def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with a single attribute result = agents.set("wealth", 0, inplace=False) - assert result.df["wealth"].to_list() == [0, 0, 0, 0] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 0, + 0, + 0, + 0, + ] # Test with a list of attributes result = agents.set(["wealth", "age"], 1, inplace=False) - assert result.df["wealth"].to_list() == [1, 1, 1, 1] - assert result.df["age"].to_list() == [1, 1, 1, 1] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 1, + 1, + 1, + ] + assert result.df.select("age").collect()["age"].to_list() == [1, 1, 1, 1] # Test with a single attribute and a mask - selected = agents.select(agents.df["wealth"] > 1, inplace=False) + selected = agents.select( + agents.df.select("wealth").collect()["wealth"] > 1, inplace=False + ) selected.set("wealth", 0, mask="active") - assert selected.df["wealth"].to_list() == [1, 0, 0, 0] + assert selected.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 0, + 0, + 0, + ] # Test with a dictionary agents.set({"wealth": 10, "age": 20}) - assert agents.df["wealth"].to_list() == [10, 10, 10, 10] - assert agents.df["age"].to_list() == [20, 20, 20, 20] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 10, + 10, + 10, + 10, + ] + assert agents.df.select("age").collect()["age"].to_list() == [ + 20, + 20, + 20, + 20, + ] # Test with Collection values (Line 213) - using list as Collection result = agents.set("wealth", [100, 200, 300, 400], inplace=False) - assert result.df["wealth"].to_list() == [100, 200, 300, 400] + assert result.df.select("wealth").collect()["wealth"].to_list() == [ + 100, + 200, + 300, + 400, + ] def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars for _ in range(10): - original_order = agents.df["unique_id"].to_list() + original_order = ( + agents.df.select("unique_id").collect()["unique_id"].to_list() + ) agents.shuffle() - if original_order != agents.df["unique_id"].to_list(): + if ( + original_order + != agents.df.select("unique_id").collect()["unique_id"].to_list() + ): return assert False def test_sort(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.sort("wealth", ascending=False) - assert agents.df["wealth"].to_list() == [4, 3, 2, 1] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 4, + 3, + 2, + 1, + ] def test__add__( self, @@ -344,19 +484,54 @@ def test__add__( # Test with an AgentSetPolars and a DataFrame agents3 = agents + pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) - assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents3.df["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents3.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with an AgentSetPolars and a list (Sequence[Any]) agents3 = agents + [5, 5] # unique_id, wealth, age - assert all(agents3.df["unique_id"].to_list()[:-1] == agents["unique_id"]) - assert len(agents3.df) == 5 - assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5] - assert agents3.df["age"].to_list() == [10, 20, 30, 40, 5] + assert all( + agents3.df.select("unique_id").collect()["unique_id"].to_list()[:-1] + == agents["unique_id"] + ) + assert len(agents3.df.collect()) == 5 + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert agents3.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 5, + ] # Test with an AgentSetPolars and a dict agents3 = agents + {"age": 10, "wealth": 5} - assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents3.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with a single value @@ -410,24 +585,50 @@ def test__iadd__( # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents += pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) - assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] - assert agents.df["age"].to_list() == [10, 20, 30, 40, 50, 60] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + assert agents.df.select("age").collect()["age"].to_list() == [ + 10, + 20, + 30, + 40, + 50, + 60, + ] # Test with an AgentSetPolars and a list agents = deepcopy(fix1_AgentSetPolars) agents += [5, 5] # unique_id, wealth, age assert all( - agents["unique_id"].to_list()[:-1] + agents.df.select("unique_id").collect()["unique_id"].to_list()[:-1] == fix1_AgentSetPolars["unique_id"][0, 1, 2, 3] ) - assert len(agents.df) == 5 - assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5] - assert agents.df["age"].to_list() == [10, 20, 30, 40, 5] + assert len(agents.df.collect()) == 5 + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] + assert agents.df.select("age").collect()["age"].to_list() == [10, 20, 30, 40, 5] # Test with an AgentSetPolars and a dict agents = deepcopy(fix1_AgentSetPolars) agents += {"age": 10, "wealth": 5} - assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + 5, + ] def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -439,7 +640,7 @@ def test__isub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents -= agents.df - assert agents.df.is_empty() + assert agents.df.collect().is_empty() def test__len__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -463,21 +664,31 @@ def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with key=str, value=Anyagents agents["wealth"] = 0 - assert agents.df["wealth"].to_list() == [0, 0, 0, 0] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 0, + 0, + 0, + 0, + ] # Test with key=list[str], value=Any agents[["wealth", "age"]] = 1 - assert agents.df["wealth"].to_list() == [1, 1, 1, 1] - assert agents.df["age"].to_list() == [1, 1, 1, 1] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 1, + 1, + 1, + ] + assert agents.df.select("age").collect()["age"].to_list() == [1, 1, 1, 1] # Test with key=tuple, value=Any agents[agents["unique_id"][0], "wealth"] = 5 - assert agents.df["wealth"].to_list() == [5, 1, 1, 1] + assert agents.df.select("wealth").collect()["wealth"].to_list() == [5, 1, 1, 1] # Test with key=AgentMask, value=Any agents[agents["unique_id"][0]] = [9, 99] - assert agents.df.item(0, "wealth") == 9 - assert agents.df.item(0, "age") == 99 + assert agents.df.collect().item(0, "wealth") == 9 + assert agents.df.collect().item(0, "age") == 99 def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents: ExampleAgentSetPolars = fix1_AgentSetPolars @@ -486,8 +697,13 @@ def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): def test__sub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents: ExampleAgentSetPolars = fix1_AgentSetPolars agents2: ExampleAgentSetPolars = agents - agents.df - assert agents2.df.is_empty() - assert agents.df["wealth"].to_list() == [1, 2, 3, 4] + assert agents2.df.collect().is_empty() + assert agents.df.select("wealth").collect()["wealth"].to_list() == [ + 1, + 2, + 3, + 4, + ] def test_get_obj(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars @@ -501,28 +717,39 @@ def test_agents( ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars - assert isinstance(agents.df, pl.DataFrame) + assert isinstance(agents.df, pl.LazyFrame) # Test agents.setter agents.df = agents2.df - assert all(agents["unique_id"] == agents2["unique_id"][0, 1, 2, 3]) + assert agents.df.select("unique_id").collect()["unique_id"].to_list() == [ + 4, + 5, + 6, + 7, + ] def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with select - agents.select(agents.df["wealth"] > 2, inplace=True) - assert all(agents.active_agents["unique_id"] == agents["unique_id"][2, 3]) + agents.select(agents.df.select("wealth").collect()["wealth"] > 2, inplace=True) + assert agents.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [2, 3] # Test with active_agents.setter - agents.active_agents = agents.df["wealth"] > 2 - assert all(agents.active_agents["unique_id"] == agents["unique_id"][2, 3]) + agents.active_agents = agents.df.select("wealth").collect()["wealth"] > 2 + assert agents.active_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [2, 3] def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars - agents.select(agents.df["wealth"] > 2, inplace=True) - assert all(agents.inactive_agents["unique_id"] == agents["unique_id"][0, 1]) + agents.select(agents.df.select("wealth").collect()["wealth"] > 2, inplace=True) + assert agents.inactive_agents.select("unique_id").collect()[ + "unique_id" + ].to_list() == [0, 1] def test_pos(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): pos = fix1_AgentSetPolars_with_pos.pos diff --git a/tests/test_mixin.py b/tests/test_mixin.py index 0ea25793..556ab4ca 100644 --- a/tests/test_mixin.py +++ b/tests/test_mixin.py @@ -12,7 +12,7 @@ def mixin(self): @pytest.fixture def df_0(self): - return pl.DataFrame( + return pl.LazyFrame( { "unique_id": ["x", "y", "z"], "A": [1, 2, 3], @@ -24,7 +24,7 @@ def df_0(self): @pytest.fixture def df_1(self): - return pl.DataFrame( + return pl.LazyFrame( { "unique_id": ["z", "a", "b"], "A": [4, 5, 6], @@ -34,33 +34,33 @@ def df_1(self): }, ) - def test_df_add(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_add(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test adding a DataFrame and a sequence element-wise along the rows (axis='index') - result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [5, 7, 9] assert result["D"].to_list() == [5, 7, 9] # Test adding a DataFrame and a sequence element-wise along the column (axis='columns') - result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [2, 3, 4] assert result["D"].to_list() == [3, 4, 5] # Test adding DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_add( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 7] assert result["D"].to_list() == [None, None, 4] def test_df_all(self, mixin: PolarsMixin): - df = pl.DataFrame( + df = pl.LazyFrame( { "A": [True, False, True], "B": [True, True, True], @@ -68,28 +68,32 @@ def test_df_all(self, mixin: PolarsMixin): ) # Test with axis='columns' - result = mixin._df_all(df["A", "B"], axis="columns") + result = mixin._df_all(df["A", "B"], axis="columns").collect() assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [True, False, True] # Test with axis='index' - result = mixin._df_all(df["A", "B"], axis="index") + result = mixin._df_all(df["A", "B"], axis="index").collect() assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [False, True] - def test_df_and(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_and(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - df_0 = df_0.with_columns(F=pl.Series([True, True, False])) - df_1 = df_1.with_columns(F=pl.Series([False, False, True])) - result = mixin._df_and(df_0[["C", "F"]], df_1["F"], axis="index") + df_0_with_f = df_0.with_columns(F=pl.lit([True, True, False])) + df_1_with_f = df_1.with_columns(F=pl.lit([False, False, True])) + result = mixin._df_and( + df_0_with_f[["C", "F"]], df_1_with_f["F"], axis="index" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [False, False, True] assert result["F"].to_list() == [False, False, False] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_and(df_0[["C", "F"]], [True, False], axis="columns") + result = mixin._df_and( + df_0_with_f[["C", "F"]], [True, False], axis="columns" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, False, True] assert result["F"].to_list() == [False, False, False] @@ -100,22 +104,22 @@ def test_df_and(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame df_1[["unique_id", "C", "F"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [None, False, False] assert result["F"].to_list() == [None, None, False] - def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.LazyFrame): cols = mixin._df_column_names(df_0) assert isinstance(cols, list) assert all(isinstance(c, str) for c in cols) assert set(mixin._df_column_names(df_0)) == {"unique_id", "A", "B", "C", "D"} def test_df_combine_first( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): # Test with df_0 and df_1 - result = mixin._df_combine_first(df_0, df_1, "unique_id") + result = mixin._df_combine_first(df_0, df_1, "unique_id").collect() result = result.sort("A") assert isinstance(result, pl.DataFrame) assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} @@ -127,7 +131,7 @@ def test_df_combine_first( assert result["E"].to_list() == [None, None, 1, 2, 3] # Test with df_1 and df_0 - result = mixin._df_combine_first(df_1, df_0, "unique_id") + result = mixin._df_combine_first(df_1, df_0, "unique_id").collect() result = result.sort("E", nulls_last=True) assert isinstance(result, pl.DataFrame) assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} @@ -139,14 +143,14 @@ def test_df_combine_first( assert result["E"].to_list() == [1, 2, 3, None, None] def test_df_concat( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): ### Test vertical concatenation ## With DataFrames for ignore_index in [False, True]: vertical = mixin._df_concat( [df_0, df_1], how="vertical", ignore_index=ignore_index - ) + ).collect() assert isinstance(vertical, pl.DataFrame) assert vertical.columns == ["unique_id", "A", "B", "C", "D", "E"] assert len(vertical) == 6 @@ -161,7 +165,7 @@ def test_df_concat( for ignore_index in [True, False]: vertical = mixin._df_concat( [df_0["A"], df_1["A"]], how="vertical", ignore_index=ignore_index - ) + ).collect() assert isinstance(vertical, pl.Series) assert len(vertical) == 6 assert vertical.to_list() == [1, 2, 3, 4, 5, 6] @@ -171,10 +175,10 @@ def test_df_concat( ## With DataFrames # Error With same column names with pytest.raises(pl.exceptions.DuplicateError): - mixin._df_concat([df_0, df_1], how="horizontal") + mixin._df_concat([df_0, df_1], how="horizontal").collect() # With ignore_index = False - df_1 = df_1.rename(lambda c: f"{c}_1") - horizontal = mixin._df_concat([df_0, df_1], how="horizontal") + df_1_renamed = df_1.rename(lambda c: f"{c}_1") + horizontal = mixin._df_concat([df_0, df_1_renamed], how="horizontal").collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == [ "unique_id", @@ -202,10 +206,10 @@ def test_df_concat( # With ignore_index = True horizontal_ignore_index = mixin._df_concat( - [df_0, df_1], + [df_0, df_1_renamed], how="horizontal", ignore_index=True, - ) + ).collect() assert isinstance(horizontal_ignore_index, pl.DataFrame) assert horizontal_ignore_index.columns == [ "0", @@ -234,8 +238,8 @@ def test_df_concat( ## With Series # With ignore_index = False horizontal = mixin._df_concat( - [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=False - ) + [df_0["A"], df_1_renamed["B_1"]], how="horizontal", ignore_index=False + ).collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == ["A", "B_1"] assert len(horizontal) == 3 @@ -244,8 +248,8 @@ def test_df_concat( # With ignore_index = True horizontal = mixin._df_concat( - [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=True - ) + [df_0["A"], df_1_renamed["B_1"]], how="horizontal", ignore_index=True + ).collect() assert isinstance(horizontal, pl.DataFrame) assert horizontal.columns == ["0", "1"] assert len(horizontal) == 3 @@ -255,7 +259,7 @@ def test_df_concat( def test_df_constructor(self, mixin: PolarsMixin): # Test with dictionary data = {"num": [1, 2, 3], "letter": ["a", "b", "c"]} - df = mixin._df_constructor(data) + df = mixin._df_constructor(data).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["num", "letter"] @@ -266,7 +270,7 @@ def test_df_constructor(self, mixin: PolarsMixin): data = [[1, "a"], [2, "b"], [3, "c"]] df = mixin._df_constructor( data, columns=["num", "letter"], dtypes={"num": "int64"} - ) + ).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["num", "letter"] assert df["num"].dtype == pl.Int64 @@ -277,65 +281,65 @@ def test_df_constructor(self, mixin: PolarsMixin): data = {"a": 5} df = mixin._df_constructor( data, index=pl.int_range(5, eager=True), index_cols="index" - ) + ).collect() assert isinstance(df, pl.DataFrame) assert list(df.columns) == ["index", "a"] assert df["a"].to_list() == [5, 5, 5, 5, 5] assert df["index"].to_list() == [0, 1, 2, 3, 4] - def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_contains(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list - result = mixin._df_contains(df_0, "A", [5, 2, 3]) + result = mixin._df_contains(df_0, "A", [5, 2, 3]).collect() assert isinstance(result, pl.Series) assert result.name == "contains" assert result.to_list() == [False, True, True] - def test_df_div(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_div(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test dividing the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [0.25, 0.4, 0.5] assert result["D"].to_list() == [0.25, 0.4, 0.5] # Test dividing the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [0.5, 1, 1.5] # Test dividing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_div( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 0.75] assert result["D"].to_list() == [None, None, 3] - def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with str - dropped = mixin._df_drop_columns(df_0, "A") + dropped = mixin._df_drop_columns(df_0, "A").collect() assert isinstance(dropped, pl.DataFrame) assert dropped.columns == ["unique_id", "B", "C", "D"] # Test with list - dropped = mixin._df_drop_columns(df_0, ["A", "C"]) + dropped = mixin._df_drop_columns(df_0, ["A", "C"]).collect() assert dropped.columns == ["unique_id", "B", "D"] - def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.LazyFrame): new_df = pl.concat([df_0, df_0], how="vertical") assert len(new_df) == 6 # Test with all columns - dropped = mixin._df_drop_duplicates(new_df) + dropped = mixin._df_drop_duplicates(new_df).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 3 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] # Test with subset (str) - other_df = pl.DataFrame( + other_df = pl.LazyFrame( { "unique_id": ["x", "y", "z"], "A": [1, 2, 3], @@ -345,156 +349,164 @@ def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame): }, ) new_df = pl.concat([df_0, other_df], how="vertical") - dropped = mixin._df_drop_duplicates(new_df, subset="unique_id") + dropped = mixin._df_drop_duplicates(new_df, subset="unique_id").collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 3 # Test with subset (list) - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"]) + dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"]).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 5 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["a", "b", "c", "e", "f"] # Test with subset (list) and keep='last' - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep="last") + dropped = mixin._df_drop_duplicates( + new_df, subset=["A", "C"], keep="last" + ).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 5 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["d", "b", "c", "e", "f"] # Test with subset (list) and keep=False - dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep=False) + dropped = mixin._df_drop_duplicates( + new_df, subset=["A", "C"], keep=False + ).collect() assert isinstance(dropped, pl.DataFrame) assert len(dropped) == 4 assert dropped.columns == ["unique_id", "A", "B", "C", "D"] assert dropped["B"].to_list() == ["b", "c", "e", "f"] - def test_df_ge(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_ge(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - result = mixin._df_ge(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_ge(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [False, False, False] assert result["D"].to_list() == [False, False, False] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_ge(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_ge(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, True, True] assert result["D"].to_list() == [False, True, True] # Test comparing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_ge( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, False] assert result["D"].to_list() == [None, None, True] - def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with pl.Series[bool] - mask = mixin._df_get_bool_mask(df_0, "A", pl.Series([True, False, True])) + mask = mixin._df_get_bool_mask( + df_0, "A", pl.Series([True, False, True]) + ).collect() assert mask.to_list() == [True, False, True] # Test with DataFrame - mask_df = pl.DataFrame({"A": [1, 3]}) - mask = mixin._df_get_bool_mask(df_0, "A", mask_df) + mask_df = pl.LazyFrame({"A": [1, 3]}) + mask = mixin._df_get_bool_mask(df_0, "A", mask_df).collect() assert mask.to_list() == [True, False, True] # Test with single value - mask = mixin._df_get_bool_mask(df_0, "A", 1) + mask = mixin._df_get_bool_mask(df_0, "A", 1).collect() assert mask.to_list() == [True, False, False] # Test with list of values - mask = mixin._df_get_bool_mask(df_0, "A", [1, 3]) + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3]).collect() assert mask.to_list() == [True, False, True] # Test with negate=True - mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True) + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True).collect() assert mask.to_list() == [False, True, False] - def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with pl.Series[bool] - masked_df = mixin._df_get_masked_df(df_0, "A", pl.Series([True, False, True])) + masked_df = mixin._df_get_masked_df( + df_0, "A", pl.Series([True, False, True]) + ).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with DataFrame - mask_df = pl.DataFrame({"A": [1, 3]}) - masked_df = mixin._df_get_masked_df(df_0, "A", mask_df) + mask_df = pl.LazyFrame({"A": [1, 3]}) + masked_df = mixin._df_get_masked_df(df_0, "A", mask_df).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with single value - masked_df = mixin._df_get_masked_df(df_0, "A", 1) + masked_df = mixin._df_get_masked_df(df_0, "A", 1).collect() assert masked_df["A"].to_list() == [1] assert masked_df["unique_id"].to_list() == ["x"] # Test with list of values - masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3]) + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3]).collect() assert masked_df["A"].to_list() == [1, 3] assert masked_df["unique_id"].to_list() == ["x", "z"] # Test with columns - masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"]) + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"]).collect() assert list(masked_df.columns) == ["B"] assert masked_df["B"].to_list() == ["a", "c"] # Test with negate=True - masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True) + masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True).collect() assert len(masked) == 1 - def test_df_groupby_cumcount(self, df_0: pl.DataFrame, mixin: PolarsMixin): - result = mixin._df_groupby_cumcount(df_0, "C") + def test_df_groupby_cumcount(self, df_0: pl.LazyFrame, mixin: PolarsMixin): + result = mixin._df_groupby_cumcount(df_0, "C").collect() assert result.to_list() == [1, 1, 2] - def test_df_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): - index = mixin._df_index(df_0, "unique_id") + def test_df_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): + index = mixin._df_index(df_0, "unique_id").collect() assert isinstance(index, pl.Series) assert index.to_list() == ["x", "y", "z"] - def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.LazyFrame): iterator = mixin._df_iterator(df_0) first_item = next(iterator) assert first_item == {"unique_id": "x", "A": 1, "B": "a", "C": True, "D": 1} def test_df_join(self, mixin: PolarsMixin): - left = pl.DataFrame({"A": [1, 2], "B": ["a", "b"]}) - right = pl.DataFrame({"A": [1, 3], "C": ["x", "y"]}) + left = pl.LazyFrame({"A": [1, 2], "B": ["a", "b"]}) + right = pl.LazyFrame({"A": [1, 3], "C": ["x", "y"]}) # Test with 'on' (left join) - joined = mixin._df_join(left, right, on="A") + joined = mixin._df_join(left, right, on="A").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 2] # Test with 'left_on' and 'right_on' (left join) - right_1 = pl.DataFrame({"D": [1, 2], "C": ["x", "y"]}) - joined = mixin._df_join(left, right_1, left_on="A", right_on="D") + right_1 = pl.LazyFrame({"D": [1, 2], "C": ["x", "y"]}) + joined = mixin._df_join(left, right_1, left_on="A", right_on="D").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 2] # Test with 'right' join - joined = mixin._df_join(left, right, on="A", how="right") + joined = mixin._df_join(left, right, on="A", how="right").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1, 3] # Test with 'inner' join - joined = mixin._df_join(left, right, on="A", how="inner") + joined = mixin._df_join(left, right, on="A", how="inner").collect() assert set(joined.columns) == {"A", "B", "C"} assert joined["A"].to_list() == [1] # Test with 'outer' join - joined = mixin._df_join(left, right, on="A", how="outer") + joined = mixin._df_join(left, right, on="A", how="outer").collect() assert set(joined.columns) == {"A", "B", "A_right", "C"} assert joined["A"].to_list() == [1, None, 2] assert joined["A_right"].to_list() == [1, 3, None] # Test with 'cross' join - joined = mixin._df_join(left, right, how="cross") + joined = mixin._df_join(left, right, how="cross").collect() assert set(joined.columns) == {"A", "B", "A_right", "C"} assert len(joined) == 4 assert joined.row(0) == (1, "a", 1, "x") @@ -503,7 +515,7 @@ def test_df_join(self, mixin: PolarsMixin): assert joined.row(3) == (2, "b", 3, "y") # Test with different 'suffix' - joined = mixin._df_join(left, right, suffix="_r", how="cross") + joined = mixin._df_join(left, right, suffix="_r", how="cross").collect() assert set(joined.columns) == {"A", "B", "A_r", "C"} assert len(joined) == 4 assert joined.row(0) == (1, "a", 1, "x") @@ -511,109 +523,113 @@ def test_df_join(self, mixin: PolarsMixin): assert joined.row(2) == (2, "b", 1, "x") assert joined.row(3) == (2, "b", 3, "y") - def test_df_lt(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_lt(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - result = mixin._df_lt(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_lt(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, True, True] assert result["D"].to_list() == [True, True, True] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_lt(df_0[["A", "D"]], [2, 3], axis="columns") + result = mixin._df_lt(df_0[["A", "D"]], [2, 3], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [True, False, False] assert result["D"].to_list() == [True, True, False] # Test comparing DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_lt( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, True] assert result["D"].to_list() == [None, None, False] - def test_df_mod(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_mod(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test taking the modulo of the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_mod(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_mod(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [1, 2, 3] # Test taking the modulo of the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_mod(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_mod(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [0, 0, 0] assert result["D"].to_list() == [1, 0, 1] # Test taking the modulo of DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_mod( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 3] assert result["D"].to_list() == [None, None, 0] - def test_df_mul(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_mul(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test multiplying the DataFrame by a sequence element-wise along the rows (axis='index') - result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index") + result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [4, 10, 18] assert result["D"].to_list() == [4, 10, 18] # Test multiplying the DataFrame by a sequence element-wise along the columns (axis='columns') - result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns") + result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns").collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [1, 2, 3] assert result["D"].to_list() == [2, 4, 6] # Test multiplying DataFrames with index-column alignment - df_1 = df_1.with_columns(D=pl.col("E")) + df_1_with_d = df_1.with_columns(D=pl.col("E")) result = mixin._df_mul( df_0[["unique_id", "A", "D"]], - df_1[["unique_id", "A", "D"]], + df_1_with_d[["unique_id", "A", "D"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["A"].to_list() == [None, None, 12] assert result["D"].to_list() == [None, None, 3] def test_df_norm(self, mixin: PolarsMixin): - df = pl.DataFrame({"A": [3, 4], "B": [4, 3]}) + df = pl.LazyFrame({"A": [3, 4], "B": [4, 3]}) # If include_cols = False - norm = mixin._df_norm(df) + norm = mixin._df_norm(df).collect() assert isinstance(norm, pl.Series) assert len(norm) == 2 assert norm[0] == 5 assert norm[1] == 5 # If include_cols = True - norm = mixin._df_norm(df, include_cols=True) + norm = mixin._df_norm(df, include_cols=True).collect() assert isinstance(norm, pl.DataFrame) assert len(norm) == 2 assert norm.columns == ["A", "B", "norm"] assert norm.row(0, named=True)["norm"] == 5 assert norm.row(1, named=True)["norm"] == 5 - def test_df_or(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + def test_df_or(self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame): # Test comparing the DataFrame with a sequence element-wise along the rows (axis='index') - df_0 = df_0.with_columns(F=pl.Series([True, True, False])) - df_1 = df_1.with_columns(F=pl.Series([False, False, True])) - result = mixin._df_or(df_0[["C", "F"]], df_1["F"], axis="index") + df_0_with_f = df_0.with_columns(F=pl.lit([True, True, False])) + df_1_with_f = df_1.with_columns(F=pl.lit([False, False, True])) + result = mixin._df_or( + df_0_with_f[["C", "F"]], df_1_with_f["F"], axis="index" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, False, True] assert result["F"].to_list() == [True, True, True] # Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns') - result = mixin._df_or(df_0[["C", "F"]], [True, False], axis="columns") + result = mixin._df_or( + df_0_with_f[["C", "F"]], [True, False], axis="columns" + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, True, True] assert result["F"].to_list() == [True, True, False] @@ -624,16 +640,16 @@ def test_df_or(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame) df_1[["unique_id", "C", "F"]], axis="index", index_cols="unique_id", - ) + ).collect() assert isinstance(result, pl.DataFrame) assert result["C"].to_list() == [True, None, True] assert result["F"].to_list() == [True, True, False] def test_df_reindex( - self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + self, mixin: PolarsMixin, df_0: pl.LazyFrame, df_1: pl.LazyFrame ): # Test with DataFrame - reindexed = mixin._df_reindex(df_0, df_1, "unique_id") + reindexed = mixin._df_reindex(df_0, df_1, "unique_id").collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["unique_id"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -642,7 +658,7 @@ def test_df_reindex( assert reindexed["D"].to_list() == [3, None, None] # Test with list - reindexed = mixin._df_reindex(df_0, ["z", "a", "b"], "unique_id") + reindexed = mixin._df_reindex(df_0, ["z", "a", "b"], "unique_id").collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["unique_id"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -656,7 +672,7 @@ def test_df_reindex( ["z", "a", "b"], new_index_cols="new_index", original_index_cols="unique_id", - ) + ).collect() assert isinstance(reindexed, pl.DataFrame) assert reindexed["new_index"].to_list() == ["z", "a", "b"] assert reindexed["A"].to_list() == [3, None, None] @@ -664,61 +680,63 @@ def test_df_reindex( assert reindexed["C"].to_list() == [True, None, None] assert reindexed["D"].to_list() == [3, None, None] - def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): - renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"]) + def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): + renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"]).collect() assert renamed.columns == ["unique_id", "X", "Y", "C", "D"] - def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # with drop = False - new_df = mixin._df_reset_index(df_0) + new_df = mixin._df_reset_index(df_0).collect() assert mixin._df_all(new_df == df_0).all() # with drop = True - new_df = mixin._df_reset_index(df_0, index_cols="unique_id", drop=True) + new_df = mixin._df_reset_index( + df_0, index_cols="unique_id", drop=True + ).collect() assert new_df.columns == ["A", "B", "C", "D"] assert len(new_df) == len(df_0) for col in new_df.columns: assert (new_df[col] == df_0[col]).all() - def test_df_remove(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_remove(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list - removed = mixin._df_remove(df_0, [1, 3], "A") + removed = mixin._df_remove(df_0, [1, 3], "A").collect() assert len(removed) == 1 assert removed["unique_id"].to_list() == ["y"] - def test_df_sample(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_sample(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with n - sampled = mixin._df_sample(df_0, n=2, seed=42) + sampled = mixin._df_sample(df_0, n=2, seed=42).collect() assert len(sampled) == 2 # Test with frac - sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42) + sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42).collect() assert len(sampled) == 2 # Test with replacement - sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42) + sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42).collect() assert len(sampled) == 4 assert sampled.n_unique() < 4 - def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.LazyFrame): index = pl.int_range(len(df_0), eager=True) - new_df = mixin._df_set_index(df_0, "index", index) + new_df = mixin._df_set_index(df_0, "index", index).collect() assert (new_df["index"] == index).all() - def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.LazyFrame): # Test with list new_df = mixin._df_with_columns( df_0, data=[[4, "d"], [5, "e"], [6, "f"]], new_columns=["D", "E"], - ) + ).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] # Test with pl.DataFrame - second_df = pl.DataFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]}) - new_df = mixin._df_with_columns(df_0, second_df) + second_df = pl.LazyFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]}) + new_df = mixin._df_with_columns(df_0, second_df).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] @@ -726,18 +744,22 @@ def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): # Test with dictionary new_df = mixin._df_with_columns( df_0, data={"D": [4, 5, 6], "E": ["d", "e", "f"]} - ) + ).collect() assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] assert new_df["D"].to_list() == [4, 5, 6] assert new_df["E"].to_list() == ["d", "e", "f"] # Test with numpy array - new_df = mixin._df_with_columns(df_0, data=np.array([4, 5, 6]), new_columns="D") + new_df = mixin._df_with_columns( + df_0, data=np.array([4, 5, 6]), new_columns="D" + ).collect() assert "D" in new_df.columns assert new_df["D"].to_list() == [4, 5, 6] # Test with pl.Series - new_df = mixin._df_with_columns(df_0, pl.Series([4, 5, 6]), new_columns="D") + new_df = mixin._df_with_columns( + df_0, pl.Series([4, 5, 6]), new_columns="D" + ).collect() assert "D" in new_df.columns assert new_df["D"].to_list() == [4, 5, 6] @@ -756,29 +778,29 @@ def test_srs_contains(self, mixin: PolarsMixin): srs = [1, 2, 3, 4, 5] # Test with single value - result = mixin._srs_contains(srs, 3) + result = mixin._srs_contains(srs, 3).collect() assert result.to_list() == [True] # Test with list - result = mixin._srs_contains(srs, [1, 3, 6]) + result = mixin._srs_contains(srs, [1, 3, 6]).collect() assert result.to_list() == [True, True, False] # Test with numpy array - result = mixin._srs_contains(srs, np.array([1, 3, 6])) + result = mixin._srs_contains(srs, np.array([1, 3, 6])).collect() assert result.to_list() == [True, True, False] def test_srs_range(self, mixin: PolarsMixin): # Test with default step - srs = mixin._srs_range("test", 0, 5) + srs = mixin._srs_range("test", 0, 5).collect() assert srs.name == "test" assert srs.to_list() == [0, 1, 2, 3, 4] # Test with custom step - srs = mixin._srs_range("test", 0, 10, step=2) + srs = mixin._srs_range("test", 0, 10, step=2).collect() assert srs.to_list() == [0, 2, 4, 6, 8] def test_srs_to_df(self, mixin: PolarsMixin): srs = pl.Series("test", [1, 2, 3]) - df = mixin._srs_to_df(srs) + df = mixin._srs_to_df(srs).collect() assert isinstance(df, pl.DataFrame) assert df["test"].to_list() == [1, 2, 3]