From d3b61f3a92d8dbfd57b6093567a058eb5c26f113 Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Fri, 31 Oct 2025 12:43:16 +0530 Subject: [PATCH 1/9] feat: Adding agent_id --- mesa_frames/concrete/datacollector.py | 55 ++++++++++++++++++++------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 2b50c76d..2f68f707 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -172,28 +172,57 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int): def _collect_agent_reporters(self, current_model_step: int, batch_id: int): """ - Collect agent-level data using the agent_reporters. + Collect agent-level data using the agent_reporters, including unique agent IDs Constructs a LazyFrame with one column per reporter and - includes `step` and `seed` metadata. Appends it to internal storage. + includes + - agent_id : unique identifier for each agent + - step, seed and batch columns for context. + - Columns for all requested agent reporters. """ - agent_data_dict = {} + all_agent_frames = [] + for col_name, reporter in self._agent_reporters.items(): if isinstance(reporter, str): - for k, v in self._model.sets[reporter].items(): - agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v + agent_set = self._model.sets[reporter] + + if hasattr(agent_set, "df"): + df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"}) + elif hasattr(agent_set, "to_polars"): + df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"}) + else: + records = [] + for agent in agent_set.values(): + if hasattr(agent, "unique_id"): + agent_id = agent.unique_id + elif hasattr(agent, "id"): + agent_id = agent.id + else: + agent_id = None + records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)}) + df = pl.DataFrame(records) else: - agent_data_dict[col_name] = reporter(self._model) - agent_lazy_frame = pl.LazyFrame(agent_data_dict) - agent_lazy_frame = agent_lazy_frame.with_columns( - [ + df = reporter(self._model) + if not isinstance(df, pl.DataFrame): + raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame") + + df = df.with_columns([ pl.lit(current_model_step).alias("step"), pl.lit(str(self.seed)).alias("seed"), pl.lit(batch_id).alias("batch"), - ] - ) - self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame)) - + ]) + all_agent_frames.append(df) + + if all_agent_frames: + merged_df = all_agent_frames[0] + for next_df in all_agent_frames[1:]: + if "agent_id" not in next_df.columns: + continue + merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer") + + agent_lazy_frame = merged_df.lazy() + self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame)) + @property def data(self) -> dict[str, pl.DataFrame]: """ From 25a8b6dd3444bfed7675a4a2b16aed02571fcea6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Oct 2025 17:23:49 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/datacollector.py | 47 +++++++++++++++++++-------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 2f68f707..54c60ff0 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -175,7 +175,7 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): Collect agent-level data using the agent_reporters, including unique agent IDs Constructs a LazyFrame with one column per reporter and - includes + includes - agent_id : unique identifier for each agent - step, seed and batch columns for context. - Columns for all requested agent reporters. @@ -187,9 +187,15 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): agent_set = self._model.sets[reporter] if hasattr(agent_set, "df"): - df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"}) + df = agent_set.df.select(["id", col_name]).rename( + {"id": "agent_id"} + ) elif hasattr(agent_set, "to_polars"): - df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"}) + df = ( + agent_set.to_polars() + .select(["id", col_name]) + .rename({"id": "agent_id"}) + ) else: records = [] for agent in agent_set.values(): @@ -199,18 +205,27 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): agent_id = agent.id else: agent_id = None - records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)}) + records.append( + { + "agent_id": agent_id, + col_name: getattr(agent, col_name, None), + } + ) df = pl.DataFrame(records) else: df = reporter(self._model) if not isinstance(df, pl.DataFrame): - raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame") - - df = df.with_columns([ - pl.lit(current_model_step).alias("step"), - pl.lit(str(self.seed)).alias("seed"), - pl.lit(batch_id).alias("batch"), - ]) + raise TypeError( + f"Agent reporter {col_name} must return a Polars DataFrame" + ) + + df = df.with_columns( + [ + pl.lit(current_model_step).alias("step"), + pl.lit(str(self.seed)).alias("seed"), + pl.lit(batch_id).alias("batch"), + ] + ) all_agent_frames.append(df) if all_agent_frames: @@ -218,11 +233,15 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): for next_df in all_agent_frames[1:]: if "agent_id" not in next_df.columns: continue - merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer") + merged_df = merged_df.join( + next_df, on=["agent_id", "step", "seed", "batch"], how="outer" + ) agent_lazy_frame = merged_df.lazy() - self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame)) - + self._frames.append( + ("agent", current_model_step, batch_id, agent_lazy_frame) + ) + @property def data(self) -> dict[str, pl.DataFrame]: """ From f1d844c417eefd230e3a6011c280940527ffa352 Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Fri, 31 Oct 2025 23:30:53 +0530 Subject: [PATCH 3/9] Fix: Add agent_id handling --- mesa_frames/concrete/datacollector.py | 38 ++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 2f68f707..0ae5e08a 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -54,6 +54,7 @@ def step(self): self.dc.flush() """ +from unittest import result import polars as pl import boto3 from urllib.parse import urlparse @@ -193,18 +194,37 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): else: records = [] for agent in agent_set.values(): - if hasattr(agent, "unique_id"): - agent_id = agent.unique_id - elif hasattr(agent, "id"): - agent_id = agent.id - else: - agent_id = None + agent_id = getattr(agent, "unique_id", getattr(agent, "id", None)) records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)}) df = pl.DataFrame(records) else: - df = reporter(self._model) - if not isinstance(df, pl.DataFrame): - raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame") + result = reporter(self._model) + + # Handle Polars DataFrame directly + if isinstance(result, pl.DataFrame): + df = result + elif isinstance(result, list): + df = pl.DataFrame(result) + elif isinstance(result, dict): + df = pl.DataFrame([result]) + + # Handle dict, list, scalar reporters + else: + # Try to build per-agent data if possible + if hasattr(self._model, "agents"): + records = [] + for agent in self._model.agents: + agent_id = getattr(agent, "unique_id", getattr(agent, "id", None)) + value = getattr(agent, col_name, result if not callable(result) else None) + records.append({"agent_id": agent_id, col_name: value}) + df = pl.DataFrame(records) + else: + # Fallback for scalar or model-level reporters + df = pl.DataFrame([{col_name: result}]) + + # Ensure column consistency + if "agent_id" not in df.columns: + df = df.with_columns(pl.lit(None).alias("agent_id")) df = df.with_columns([ pl.lit(current_model_step).alias("step"), From f1d2e1c3bf49a3eef885795fcd48bfad9e24cd6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Oct 2025 18:25:05 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/datacollector.py | 52 +++++++++++++++++++-------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index ed3cdb45..47d979a2 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -187,14 +187,27 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): agent_set = self._model.sets[reporter] if hasattr(agent_set, "df"): - df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"}) + df = agent_set.df.select(["id", col_name]).rename( + {"id": "agent_id"} + ) elif hasattr(agent_set, "to_polars"): - df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"}) + df = ( + agent_set.to_polars() + .select(["id", col_name]) + .rename({"id": "agent_id"}) + ) else: records = [] for agent in agent_set.values(): - agent_id = getattr(agent, "unique_id", getattr(agent, "id", None)) - records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)}) + agent_id = getattr( + agent, "unique_id", getattr(agent, "id", None) + ) + records.append( + { + "agent_id": agent_id, + col_name: getattr(agent, col_name, None), + } + ) df = pl.DataFrame(records) else: @@ -213,8 +226,14 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): if hasattr(self._model, "agents"): records = [] for agent in self._model.agents: - agent_id = getattr(agent, "unique_id", getattr(agent, "id", None)) - value = getattr(agent, col_name, result if not callable(result) else None) + agent_id = getattr( + agent, "unique_id", getattr(agent, "id", None) + ) + value = getattr( + agent, + col_name, + result if not callable(result) else None, + ) records.append({"agent_id": agent_id, col_name: value}) df = pl.DataFrame(records) else: @@ -225,11 +244,13 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): df = df.with_columns(pl.lit(None).alias("agent_id")) ## Add meta columns - df = df.with_columns([ - pl.lit(current_model_step).alias("step"), - pl.lit(str(self.seed)).alias("seed"), - pl.lit(batch_id).alias("batch"), - ]) + df = df.with_columns( + [ + pl.lit(current_model_step).alias("step"), + pl.lit(str(self.seed)).alias("seed"), + pl.lit(batch_id).alias("batch"), + ] + ) all_agent_frames.append(df) if all_agent_frames: @@ -237,11 +258,14 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): for next_df in all_agent_frames[1:]: if "agent_id" not in next_df.columns: continue - merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer") + merged_df = merged_df.join( + next_df, on=["agent_id", "step", "seed", "batch"], how="outer" + ) agent_lazy_frame = merged_df.lazy() - self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame)) - + self._frames.append( + ("agent", current_model_step, batch_id, agent_lazy_frame) + ) @property def data(self) -> dict[str, pl.DataFrame]: From 6806158ae5b6b9e10b422a46c983d4e256c0382f Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Sat, 1 Nov 2025 01:08:54 +0530 Subject: [PATCH 5/9] fix: Changed test file for datacollector --- tests/test_datacollector.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index b7407711..651df5c9 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -185,8 +185,9 @@ def test_collect(self, fix1_model): with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["model"]["max_wealth"] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 8) assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -195,6 +196,7 @@ def test_collect(self, fix1_model): "seed", "batch", } + assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3] assert collected_data["agent"]["wealth"].to_list() == [1, 2, 3, 4] assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ 10, @@ -242,8 +244,9 @@ def test_collect_step(self, fix1_model): assert collected_data["model"]["step"].to_list() == [5] assert collected_data["model"]["total_agents"].to_list() == [12] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 8) assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -252,6 +255,7 @@ def test_collect_step(self, fix1_model): "seed", "batch", } + assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3] assert collected_data["agent"]["wealth"].to_list() == [6, 7, 8, 9] assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ 10, @@ -297,8 +301,9 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["step"].to_list() == [2, 4] assert collected_data["model"]["total_agents"].to_list() == [12, 12] - assert collected_data["agent"].shape == (8, 7) + assert collected_data["agent"].shape == (8, 8) assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -308,6 +313,7 @@ def test_conditional_collect(self, fix1_model): "batch", } assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -399,6 +405,7 @@ def test_flush_local_csv(self, fix1_model): schema_overrides={"seed": pl.Utf8}, ) assert set(agent_df.columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -580,8 +587,9 @@ def test_batch_memory(self, fix2_model): assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1] assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12] - assert collected_data["agent"].shape == (16, 7) + assert collected_data["agent"].shape == (16, 8) assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -592,6 +600,7 @@ def test_batch_memory(self, fix2_model): } assert set(collected_data["agent"].columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -779,6 +788,7 @@ def test_batch_save(self, fix2_model): schema_overrides={"seed": pl.Utf8}, ) assert set(agent_df_step2_batch0.columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -813,6 +823,7 @@ def test_batch_save(self, fix2_model): schema_overrides={"seed": pl.Utf8}, ) assert set(agent_df_step2_batch1.columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", @@ -847,6 +858,7 @@ def test_batch_save(self, fix2_model): schema_overrides={"seed": pl.Utf8}, ) assert set(agent_df_step4_batch0.columns) == { + "agent_id", "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", From 5c1525f184a8f8a36013903113c3c19e1c6952d5 Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Fri, 7 Nov 2025 02:17:03 +0530 Subject: [PATCH 6/9] Fix: Add unique_id and agent_type to agent reports --- mesa_frames/concrete/datacollector.py | 176 ++++---- tests/test_datacollector.py | 605 +++++++++----------------- 2 files changed, 281 insertions(+), 500 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 47d979a2..6a834071 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -54,7 +54,6 @@ def step(self): self.dc.flush() """ -from unittest import result import polars as pl import boto3 from urllib.parse import urlparse @@ -65,6 +64,7 @@ def step(self): from collections.abc import Callable from mesa_frames import Model from psycopg2.extensions import connection +import logging class DataCollector(AbstractDataCollector): @@ -72,7 +72,7 @@ def __init__( self, model: Model, model_reporters: dict[str, Callable] | None = None, - agent_reporters: dict[str, str | Callable] | None = None, + agent_reporters: dict[str, str] | None = None, trigger: Callable[[Any], bool] | None = None, reset_memory: bool = True, storage: Literal[ @@ -106,6 +106,14 @@ def __init__( max_worker : int Maximum number of worker threads used for flushing collected data asynchronously """ + if agent_reporters: + for key, value in agent_reporters.items(): + if not isinstance(value, str): + raise TypeError( + f"Agent reporter for '{key}' must be a string (the column name), " + f"not a {type(value)}. Callable reporters are not supported for agents." + ) + super().__init__( model=model, model_reporters=model_reporters, @@ -173,99 +181,73 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int): def _collect_agent_reporters(self, current_model_step: int, batch_id: int): """ - Collect agent-level data using the agent_reporters, including unique agent IDs. + Collect agent-level data using the agent_reporters. - Constructs a LazyFrame with one column per reporter and includes: - - agent_id : unique identifier for each agent - - step, seed, and batch columns for context - - Columns for all requested agent reporters + This method iterates through all AgentSets in the model, selects the + `unique_id` and the requested reporter columns from each AgentSet's + DataFrame, adds an `agent_type` column, and concatenates them + into a single "long" format LazyFrame. """ all_agent_frames = [] + reporter_map = self._agent_reporters + + try: + agent_sets_list = self._model.sets._agentsets + except AttributeError: + logging.error( + "DataCollector could not find '_agentsets' attribute on model.sets. " + "Agent data collection will be skipped." + ) + return - for col_name, reporter in self._agent_reporters.items(): - if isinstance(reporter, str): - agent_set = self._model.sets[reporter] + for agent_set in agent_sets_list: + if not hasattr(agent_set, "df"): + logging.warning( + f"AgentSet {agent_set.__class__.__name__} has no 'df' attribute. Skipping." + ) + continue - if hasattr(agent_set, "df"): - df = agent_set.df.select(["id", col_name]).rename( - {"id": "agent_id"} - ) - elif hasattr(agent_set, "to_polars"): - df = ( - agent_set.to_polars() - .select(["id", col_name]) - .rename({"id": "agent_id"}) - ) - else: - records = [] - for agent in agent_set.values(): - agent_id = getattr( - agent, "unique_id", getattr(agent, "id", None) - ) - records.append( - { - "agent_id": agent_id, - col_name: getattr(agent, col_name, None), - } - ) - df = pl.DataFrame(records) + agent_df = agent_set.df.lazy() + agent_type = agent_set.__class__.__name__ + available_cols = agent_df.columns - else: - result = reporter(self._model) - - ## Case 1: already a DataFrame - if isinstance(result, pl.DataFrame): - df = result - ## Case 2: dict or list -> convert - elif isinstance(result, dict): - df = pl.DataFrame([result]) - elif isinstance(result, list): - df = pl.DataFrame(result) - else: - ## Case 3: scalar or callable reporter - if hasattr(self._model, "agents"): - records = [] - for agent in self._model.agents: - agent_id = getattr( - agent, "unique_id", getattr(agent, "id", None) - ) - value = getattr( - agent, - col_name, - result if not callable(result) else None, - ) - records.append({"agent_id": agent_id, col_name: value}) - df = pl.DataFrame(records) - else: - df = pl.DataFrame([{col_name: result}]) - - ## Ensure agent_id exists - if "agent_id" not in df.columns: - df = df.with_columns(pl.lit(None).alias("agent_id")) - - ## Add meta columns - df = df.with_columns( - [ - pl.lit(current_model_step).alias("step"), - pl.lit(str(self.seed)).alias("seed"), - pl.lit(batch_id).alias("batch"), - ] - ) - all_agent_frames.append(df) - - if all_agent_frames: - merged_df = all_agent_frames[0] - for next_df in all_agent_frames[1:]: - if "agent_id" not in next_df.columns: - continue - merged_df = merged_df.join( - next_df, on=["agent_id", "step", "seed", "batch"], how="outer" + if "unique_id" not in available_cols: + logging.warning( + f"AgentSet {agent_type} 'df' has no 'unique_id' column. Skipping." ) + continue + + cols_to_select = [pl.col("unique_id")] + + for final_name, source_col in reporter_map.items(): + if source_col in available_cols: + ## Add the column, aliasing it if the key is different + cols_to_select.append(pl.col(source_col).alias(final_name)) + + ## Only proceed if we have more than just unique_id + if len(cols_to_select) > 1: + set_frame = agent_df.select(cols_to_select) + ## Add the agent_type column + set_frame = set_frame.with_columns( + pl.lit(agent_type).alias("agent_type") + ) + all_agent_frames.append(set_frame) - agent_lazy_frame = merged_df.lazy() - self._frames.append( - ("agent", current_model_step, batch_id, agent_lazy_frame) - ) + if not all_agent_frames: + return + + ## Combine all agent set DataFrames into one + final_agent_frame = pl.concat(all_agent_frames, how="diagonal_relaxed") + + ## Add metadata and append + final_agent_frame = final_agent_frame.with_columns( + [ + pl.lit(current_model_step).alias("step"), + pl.lit(str(self.seed)).alias("seed"), + pl.lit(batch_id).alias("batch"), + ] + ) + self._frames.append(("agent", current_model_step, batch_id, final_agent_frame)) @property def data(self) -> dict[str, pl.DataFrame]: @@ -534,13 +516,20 @@ def _validate_reporter_table_columns( If any expected columns are missing from the table. """ expected_columns = set() + + ## Add columns required for the new long agent format + if table_name == "agent_data": + expected_columns.add("unique_id") + expected_columns.add("agent_type") + + ## Add all keys from the reporter dict for col_name, required_column in reporter.items(): - if isinstance(required_column, str): - for k, v in self._model.sets[required_column].items(): - expected_columns.add( - (col_name + "_" + str(k.__class__.__name__)).lower() - ) + if table_name == "agent_data": + if isinstance(required_column, str): + expected_columns.add(col_name.lower()) + ## Callables are not supported for agents else: + ## For model, all reporters are callable expected_columns.add(col_name.lower()) query = f""" @@ -560,6 +549,7 @@ def _validate_reporter_table_columns( required_columns = { "step": "Integer", "seed": "Varchar", + "batch": "Integer" } missing_required = { @@ -606,4 +596,4 @@ def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple """ with conn.cursor() as cur: cur.execute(query) - return cur.fetchall() + return cur.fetchall() \ No newline at end of file diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 651df5c9..ac247494 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -9,17 +9,22 @@ def custom_trigger(model): - return model._steps % 2 == 0 + return model.steps % 2 == 0 class ExampleAgentSet1(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) - self["age"] = pl.Series("age", [10, 20, 30, 40]) + self._df = pl.DataFrame( + { + "unique_id": [101, 102, 103, 104], + "wealth": [1, 2, 3, 4], + "age": [10, 20, 30, 40], + } + ) def add_wealth(self, amount: int) -> None: - self["wealth"] += amount + self.set("wealth", self["wealth"] + amount) def step(self) -> None: self.add_wealth(1) @@ -28,11 +33,16 @@ def step(self) -> None: class ExampleAgentSet2(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) - self["age"] = pl.Series("age", [11, 22, 33, 44]) + self._df = pl.DataFrame( + { + "unique_id": [201, 202, 203, 204], + "wealth": [10, 20, 30, 40], + "age": [11, 22, 33, 44], + } + ) def add_wealth(self, amount: int) -> None: - self["wealth"] += amount + self.set("wealth", self["wealth"] + amount) def step(self) -> None: self.add_wealth(2) @@ -41,11 +51,16 @@ def step(self) -> None: class ExampleAgentSet3(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["age"] = pl.Series("age", [1, 2, 3, 4]) - self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) + self._df = pl.DataFrame( + { + "unique_id": [301, 302, 303, 304], + "age": [1, 2, 3, 4], + "wealth": [1, 2, 3, 4], + } + ) def age_agents(self, amount: int) -> None: - self["age"] += amount + self.set("age", self["age"] + amount) def step(self) -> None: self.age_agents(1) @@ -55,6 +70,7 @@ class ExampleModel(Model): def __init__(self, sets: AgentSetRegistry): super().__init__() self.sets = sets + self._steps = 0 def step(self): self.sets.do("step") @@ -78,6 +94,7 @@ class ExampleModelWithMultipleCollects(Model): def __init__(self, agents: AgentSetRegistry): super().__init__() self.sets = agents + self._seed = 0 def step(self): self.dc.conditional_collect() @@ -96,17 +113,17 @@ def postgres_uri(): @pytest.fixture def fix1_AgentSet() -> ExampleAgentSet1: - return ExampleAgentSet1(Model()) + return ExampleAgentSet1(Model(seed = 1)) @pytest.fixture def fix2_AgentSet() -> ExampleAgentSet2: - return ExampleAgentSet2(Model()) + return ExampleAgentSet2(Model(seed = 1)) @pytest.fixture def fix3_AgentSet() -> ExampleAgentSet3: - return ExampleAgentSet3(Model()) + return ExampleAgentSet3(Model(seed = 1)) @pytest.fixture @@ -117,30 +134,52 @@ def fix_AgentSetRegistry( ) -> AgentSetRegistry: model = Model() agents = AgentSetRegistry(model) - agents.add([fix1_AgentSet, fix2_AgentSet, fix3_AgentSet]) + agents._agentsets = [fix1_AgentSet, fix2_AgentSet, fix3_AgentSet] + # Manually update model link for agent sets + fix1_AgentSet._model = model + fix2_AgentSet._model = model + fix3_AgentSet._model = model return agents @pytest.fixture def fix1_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: - return ExampleModel(fix_AgentSetRegistry) + model = ExampleModel(fix_AgentSetRegistry) + fix_AgentSetRegistry._model = model + for s in fix_AgentSetRegistry._agentsets: + s._model = model + return model @pytest.fixture def fix2_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: - return ExampleModelWithMultipleCollects(fix_AgentSetRegistry) + model = ExampleModelWithMultipleCollects(fix_AgentSetRegistry) + fix_AgentSetRegistry._model = model + for s in fix_AgentSetRegistry._agentsets: + s._model = model + return model class TestDataCollector: def test__init__(self, fix1_model, postgres_uri): model = fix1_model + # with pytest.raises( + # beartype.roar.BeartypeCallHintParamViolation, + # match="not instance of .*Callable", + # ): + # model.test_dc = DataCollector( + # model=model, model_reporters={"total_agents": "sum"} + # ) + with pytest.raises( - beartype.roar.BeartypeCallHintParamViolation, - match="not instance of .*Callable", + TypeError, + match="Agent reporter for 'wealth' must be a string", ): model.test_dc = DataCollector( - model=model, model_reporters={"total_agents": "sum"} + model=model, + agent_reporters={"wealth": lambda m: 1} # This is now illegal ) + with pytest.raises( ValueError, match="Please define a storage_uri to if to be stored not in memory", @@ -164,7 +203,7 @@ def test_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -185,33 +224,31 @@ def test_collect(self, fix1_model): with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["model"]["max_wealth"] - assert collected_data["agent"].shape == (4, 8) - assert set(collected_data["agent"].columns) == { - "agent_id", + agent_df = collected_data["agent"] + + ## 3 agent sets * 4 agents/set = 12 rows + assert agent_df.shape == (12, 7) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3] - assert collected_data["agent"]["wealth"].to_list() == [1, 2, 3, 4] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [1, 2, 3, 4] - assert collected_data["agent"]["step"].to_list() == [0, 0, 0, 0] + + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] + expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [1, 2, 3, 4] + + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age) + + type_counts = agent_df["agent_type"].value_counts(sort=True) + assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0] == 4 + assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0] == 4 + assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0] == 4 + assert agent_df["step"].to_list() == [0] * 12 with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["agent"]["max_wealth"] @@ -225,7 +262,7 @@ def test_collect_step(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -244,33 +281,18 @@ def test_collect_step(self, fix1_model): assert collected_data["model"]["step"].to_list() == [5] assert collected_data["model"]["total_agents"].to_list() == [12] - assert collected_data["agent"].shape == (4, 8) - assert set(collected_data["agent"].columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", + agent_df = collected_data["agent"] + assert agent_df.shape == (12, 7) + assert set(agent_df.columns) == { + "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" } - assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3] - assert collected_data["agent"]["wealth"].to_list() == [6, 7, 8, 9] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [6, 7, 8, 9] - assert collected_data["agent"]["step"].to_list() == [5, 5, 5, 5] + + expected_wealth = [6, 7, 8, 9] + [20, 30, 40, 50] + [1, 2, 3, 4] + expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [6, 7, 8, 9] + + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age) + assert agent_df["step"].to_list() == [5] * 12 def test_conditional_collect(self, fix1_model): model = fix1_model @@ -283,7 +305,7 @@ def test_conditional_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -301,59 +323,29 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["step"].to_list() == [2, 4] assert collected_data["model"]["total_agents"].to_list() == [12, 12] - assert collected_data["agent"].shape == (8, 8) - assert set(collected_data["agent"].columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", + agent_df = collected_data["agent"] + + # 12 agents * 2 steps = 24 rows + assert agent_df.shape == (24, 7) + assert set(agent_df.columns) == { + "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" } - assert set(collected_data["agent"].columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - 5, - 6, - 7, - 8, - ] - assert collected_data["agent"]["step"].to_list() == [2, 2, 2, 2, 4, 4, 4, 4] + + df_step_2 = agent_df.filter(pl.col("step") == 2) + expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] + + assert df_step_2.shape == (12, 7) + assert sorted(df_step_2["wealth"].to_list()) == sorted(expected_wealth_s2) + assert sorted(df_step_2["age"].to_list()) == sorted(expected_age_s2) + + df_step_4 = agent_df.filter(pl.col("step") == 4) + expected_wealth_s4 = [5, 6, 7, 8] + [18, 28, 38, 48] + [1, 2, 3, 4] + expected_age_s4 = [10, 20, 30, 40] + [11, 22, 33, 44] + [5, 6, 7, 8] + + assert df_step_4.shape == (12, 7) + assert sorted(df_step_4["wealth"].to_list()) == sorted(expected_wealth_s4) + assert sorted(df_step_4["age"].to_list()) == sorted(expected_age_s4) def test_flush_local_csv(self, fix1_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -367,7 +359,7 @@ def test_flush_local_csv(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="csv", @@ -404,33 +396,25 @@ def test_flush_local_csv(self, fix1_model): os.path.join(tmpdir, "agent_step2_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) + assert agent_df.shape == (12, 7) assert set(agent_df.columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", + "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" } - assert agent_df["step"].to_list() == [2, 2, 2, 2] - assert agent_df["wealth"].to_list() == [3, 4, 5, 6] - assert agent_df["age_ExampleAgentSet1"].to_list() == [10, 20, 30, 40] - assert agent_df["age_ExampleAgentSet2"].to_list() == [11, 22, 33, 44] - assert agent_df["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - ] - - agent_df = pl.read_csv( + + expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] + + assert agent_df["step"].to_list() == [2] * 12 + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth_s2) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age_s2) + + agent_df_s4 = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert agent_df["step"].to_list() == [4, 4, 4, 4] - assert agent_df["wealth"].to_list() == [5, 6, 7, 8] + expected_wealth_s4 = [5, 6, 7, 8] + [18, 28, 38, 48] + [1, 2, 3, 4] + assert agent_df_s4["step"].to_list() == [4] * 12 + assert sorted(agent_df_s4["wealth"].to_list()) == sorted(expected_wealth_s4) def test_flush_local_parquet(self, fix1_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -444,7 +428,7 @@ def test_flush_local_parquet(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", }, storage="parquet", storage_uri=tmpdir, @@ -473,8 +457,15 @@ def test_flush_local_parquet(self, fix1_model): agent_df = pl.read_parquet( os.path.join(tmpdir, "agent_step0_batch0.parquet") ) - assert agent_df["step"].to_list() == [0, 0, 0, 0] - assert agent_df["wealth"].to_list() == [1, 2, 3, 4] + # 12 rows. 6 cols: unique_id, agent_type, wealth, step, seed, batch + assert agent_df.shape == (12, 6) + assert set(agent_df.columns) == { + "unique_id", "agent_type", "wealth", "step", "seed", "batch" + } + + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] + assert agent_df["step"].to_list() == [0] * 12 + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) @pytest.mark.skipif( os.getenv("SKIP_PG_TESTS") == "true", @@ -485,9 +476,17 @@ def test_postgress(self, fix1_model, postgres_uri): # Connect directly and validate data import psycopg2 - - conn = psycopg2.connect(postgres_uri) + + try: + conn = psycopg2.connect(postgres_uri) + except psycopg2.OperationalError as e: + pytest.skip(f"Could not connect to PostgreSQL: {e}") + cur = conn.cursor() + + ## Cleaning up tables first + cur.execute("DROP TABLE IF EXISTS public.model_data, public.agent_data;") + conn.commit() cur.execute(""" CREATE TABLE public.model_data ( @@ -498,15 +497,16 @@ def test_postgress(self, fix1_model, postgres_uri): ) """) + ## MODIFIED: CREATE TABLE for long format cur.execute(""" CREATE TABLE public.agent_data ( step INTEGER, seed VARCHAR, batch INTEGER, - age_ExampleAgentSet1 INTEGER, - age_ExampleAgentSet2 INTEGER, - age_ExampleAgentSet3 INTEGER, - wealth INTEGER + unique_id BIGINT, + agent_type VARCHAR, + wealth INTEGER, + age INTEGER ) """) conn.commit() @@ -519,8 +519,9 @@ def test_postgress(self, fix1_model, postgres_uri): len(agentset) for agentset in model.sets._agentsets ) }, + ## MODIFIED : long format agent reporters agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="postgresql", @@ -528,7 +529,7 @@ def test_postgress(self, fix1_model, postgres_uri): storage_uri=postgres_uri, ) - model.run_model_with_conditional_collect(4) + model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4. model.dc.flush() # Connect directly and validate data @@ -544,17 +545,21 @@ def test_postgress(self, fix1_model, postgres_uri): model_rows = cur.fetchall() assert model_rows == [(2, 12), (4, 12)] + # MODIFIED: Check agent data cur.execute( - "SELECT step, batch, wealth,age_ExampleAgentSet1, age_ExampleAgentSet2, age_ExampleAgentSet3 FROM agent_data WHERE step=2 ORDER BY wealth" + "SELECT wealth, age FROM agent_data WHERE step=2 ORDER BY wealth, age" ) agent_rows = cur.fetchall() - assert agent_rows == [ - (2, 0, 3, 10, 11, 3), - (2, 0, 4, 20, 22, 4), - (2, 0, 5, 30, 33, 5), - (2, 0, 6, 40, 44, 6), + + expected_rows_s2 = [ + (1, 3), (2, 4), (3, 5), (3, 10), (4, 6), (4, 20), + (5, 30), (6, 40), (14, 11), (24, 22), (34, 33), (44, 44) ] + + assert sorted(agent_rows) == sorted(expected_rows_s2) + cur.execute("DROP TABLE public.model_data, public.agent_data;") + conn.commit() cur.close() conn.close() @@ -568,141 +573,33 @@ def test_batch_memory(self, fix2_model): len(agentset) for agentset in model.sets._agentsets ) }, - agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], - "age": "age", - }, + agent_reporters={ "wealth": "wealth", "age": "age" }, ) model.run_model_with_conditional_collect_multiple_batch(5) collected_data = model.dc.data + assert collected_data["model"].shape == (4, 4) - assert set(collected_data["model"].columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert collected_data["model"]["step"].to_list() == [2, 2, 4, 4] assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1] - assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12] - assert collected_data["agent"].shape == (16, 8) - assert set(collected_data["agent"].columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", + agent_df = collected_data["agent"] + assert agent_df.shape == (48, 7) + assert set(agent_df.columns) == { + "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" } - assert set(collected_data["agent"].columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } + df_s2_b0 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 0)) + expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4] + assert sorted(df_s2_b0["wealth"].to_list()) == sorted(expected_wealth_s2b0) - assert collected_data["agent"]["step"].to_list() == [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ] - assert collected_data["agent"]["wealth"].to_list() == [ - 2, - 3, - 4, - 5, - 3, - 4, - 5, - 6, - 4, - 5, - 6, - 7, - 5, - 6, - 7, - 8, - ] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [ - 2, - 3, - 4, - 5, - 3, - 4, - 5, - 6, - 4, - 5, - 6, - 7, - 5, - 6, - 7, - 8, - ] + df_s2_b1 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 1)) + expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + assert sorted(df_s2_b1["wealth"].to_list()) == sorted(expected_wealth_s2b1) - with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): - collected_data["agent"]["max_wealth"] + df_s4_b0 = agent_df.filter((pl.col("step") == 4) & (pl.col("batch") == 0)) + expected_wealth_s4b0 = [4, 5, 6, 7] + [16, 26, 36, 46] + [1, 2, 3, 4] + assert sorted(df_s4_b0["wealth"].to_list()) == sorted(expected_wealth_s4b0) def test_batch_save(self, fix2_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -716,7 +613,7 @@ def test_batch_save(self, fix2_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="csv", @@ -725,9 +622,10 @@ def test_batch_save(self, fix2_model): model.run_model_with_conditional_collect_multiple_batch(5) model.dc.flush() + for _ in range(20): # wait up to ~2 seconds created_files = os.listdir(tmpdir) - if len(created_files) >= 4: + if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files break time.sleep(0.1) @@ -747,143 +645,36 @@ def test_batch_save(self, fix2_model): os.path.join(tmpdir, "model_step2_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step2_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert model_df_step2_batch0["step"].to_list() == [2] assert model_df_step2_batch0["total_agents"].to_list() == [12] - model_df_step2_batch0 = pl.read_csv( + model_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "model_step2_batch1.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step2_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } - assert model_df_step2_batch0["step"].to_list() == [2] - assert model_df_step2_batch0["total_agents"].to_list() == [12] - + assert model_df_step2_batch1["step"].to_list() == [2] + assert model_df_step2_batch1["total_agents"].to_list() == [12] + model_df_step4_batch0 = pl.read_csv( os.path.join(tmpdir, "model_step4_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step4_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert model_df_step4_batch0["step"].to_list() == [4] assert model_df_step4_batch0["total_agents"].to_list() == [12] - + # test agent batch reset agent_df_step2_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64}, ) - assert set(agent_df_step2_batch0.columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step2_batch0["step"].to_list() == [2, 2, 2, 2] - assert agent_df_step2_batch0["wealth"].to_list() == [2, 3, 4, 5] - assert agent_df_step2_batch0["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step2_batch0["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step2_batch0["age_ExampleAgentSet3"].to_list() == [ - 2, - 3, - 4, - 5, - ] + + expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4] + assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted(expected_wealth_s2b0) agent_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch1.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64}, ) - assert set(agent_df_step2_batch1.columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step2_batch1["step"].to_list() == [2, 2, 2, 2] - assert agent_df_step2_batch1["wealth"].to_list() == [3, 4, 5, 6] - assert agent_df_step2_batch1["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step2_batch1["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step2_batch1["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - ] - - agent_df_step4_batch0 = pl.read_csv( - os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, - ) - assert set(agent_df_step4_batch0.columns) == { - "agent_id", - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step4_batch0["step"].to_list() == [4, 4, 4, 4] - assert agent_df_step4_batch0["wealth"].to_list() == [4, 5, 6, 7] - assert agent_df_step4_batch0["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step4_batch0["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step4_batch0["age_ExampleAgentSet3"].to_list() == [ - 4, - 5, - 6, - 7, - ] + expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(expected_wealth_s2b1) + \ No newline at end of file From c5be394686f5c71acde7ca3535e2d24aecae69ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Nov 2025 20:47:23 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/datacollector.py | 14 +-- tests/test_datacollector.py | 127 ++++++++++++++++++-------- 2 files changed, 94 insertions(+), 47 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 6a834071..f954c62a 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -113,7 +113,7 @@ def __init__( f"Agent reporter for '{key}' must be a string (the column name), " f"not a {type(value)}. Callable reporters are not supported for agents." ) - + super().__init__( model=model, model_reporters=model_reporters, @@ -223,7 +223,7 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): if source_col in available_cols: ## Add the column, aliasing it if the key is different cols_to_select.append(pl.col(source_col).alias(final_name)) - + ## Only proceed if we have more than just unique_id if len(cols_to_select) > 1: set_frame = agent_df.select(cols_to_select) @@ -516,7 +516,7 @@ def _validate_reporter_table_columns( If any expected columns are missing from the table. """ expected_columns = set() - + ## Add columns required for the new long agent format if table_name == "agent_data": expected_columns.add("unique_id") @@ -546,11 +546,7 @@ def _validate_reporter_table_columns( existing_columns = {row[0] for row in result} missing_columns = expected_columns - existing_columns - required_columns = { - "step": "Integer", - "seed": "Varchar", - "batch": "Integer" - } + required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"} missing_required = { col: col_type @@ -596,4 +592,4 @@ def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple """ with conn.cursor() as cur: cur.execute(query) - return cur.fetchall() \ No newline at end of file + return cur.fetchall() diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index ac247494..2f119df0 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -113,17 +113,17 @@ def postgres_uri(): @pytest.fixture def fix1_AgentSet() -> ExampleAgentSet1: - return ExampleAgentSet1(Model(seed = 1)) + return ExampleAgentSet1(Model(seed=1)) @pytest.fixture def fix2_AgentSet() -> ExampleAgentSet2: - return ExampleAgentSet2(Model(seed = 1)) + return ExampleAgentSet2(Model(seed=1)) @pytest.fixture def fix3_AgentSet() -> ExampleAgentSet3: - return ExampleAgentSet3(Model(seed = 1)) + return ExampleAgentSet3(Model(seed=1)) @pytest.fixture @@ -176,8 +176,8 @@ def test__init__(self, fix1_model, postgres_uri): match="Agent reporter for 'wealth' must be a string", ): model.test_dc = DataCollector( - model=model, - agent_reporters={"wealth": lambda m: 1} # This is now illegal + model=model, + agent_reporters={"wealth": lambda m: 1}, # This is now illegal ) with pytest.raises( @@ -225,7 +225,7 @@ def test_collect(self, fix1_model): collected_data["model"]["max_wealth"] agent_df = collected_data["agent"] - + ## 3 agent sets * 4 agents/set = 12 rows assert agent_df.shape == (12, 7) assert set(agent_df.columns) == { @@ -237,17 +237,26 @@ def test_collect(self, fix1_model): "seed", "batch", } - + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [1, 2, 3, 4] - + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) assert sorted(agent_df["age"].to_list()) == sorted(expected_age) type_counts = agent_df["agent_type"].value_counts(sort=True) - assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0] == 4 - assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0] == 4 - assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0] == 4 + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0] + == 4 + ) + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0] + == 4 + ) + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0] + == 4 + ) assert agent_df["step"].to_list() == [0] * 12 with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["agent"]["max_wealth"] @@ -284,12 +293,18 @@ def test_collect_step(self, fix1_model): agent_df = collected_data["agent"] assert agent_df.shape == (12, 7) assert set(agent_df.columns) == { - "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" + "unique_id", + "agent_type", + "wealth", + "age", + "step", + "seed", + "batch", } expected_wealth = [6, 7, 8, 9] + [20, 30, 40, 50] + [1, 2, 3, 4] expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [6, 7, 8, 9] - + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) assert sorted(agent_df["age"].to_list()) == sorted(expected_age) assert agent_df["step"].to_list() == [5] * 12 @@ -324,17 +339,23 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["total_agents"].to_list() == [12, 12] agent_df = collected_data["agent"] - + # 12 agents * 2 steps = 24 rows assert agent_df.shape == (24, 7) assert set(agent_df.columns) == { - "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" + "unique_id", + "agent_type", + "wealth", + "age", + "step", + "seed", + "batch", } df_step_2 = agent_df.filter(pl.col("step") == 2) expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] - + assert df_step_2.shape == (12, 7) assert sorted(df_step_2["wealth"].to_list()) == sorted(expected_wealth_s2) assert sorted(df_step_2["age"].to_list()) == sorted(expected_age_s2) @@ -398,12 +419,18 @@ def test_flush_local_csv(self, fix1_model): ) assert agent_df.shape == (12, 7) assert set(agent_df.columns) == { - "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" + "unique_id", + "agent_type", + "wealth", + "age", + "step", + "seed", + "batch", } - + expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] - + assert agent_df["step"].to_list() == [2] * 12 assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth_s2) assert sorted(agent_df["age"].to_list()) == sorted(expected_age_s2) @@ -460,9 +487,14 @@ def test_flush_local_parquet(self, fix1_model): # 12 rows. 6 cols: unique_id, agent_type, wealth, step, seed, batch assert agent_df.shape == (12, 6) assert set(agent_df.columns) == { - "unique_id", "agent_type", "wealth", "step", "seed", "batch" + "unique_id", + "agent_type", + "wealth", + "step", + "seed", + "batch", } - + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] assert agent_df["step"].to_list() == [0] * 12 assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) @@ -476,14 +508,14 @@ def test_postgress(self, fix1_model, postgres_uri): # Connect directly and validate data import psycopg2 - + try: conn = psycopg2.connect(postgres_uri) except psycopg2.OperationalError as e: pytest.skip(f"Could not connect to PostgreSQL: {e}") - + cur = conn.cursor() - + ## Cleaning up tables first cur.execute("DROP TABLE IF EXISTS public.model_data, public.agent_data;") conn.commit() @@ -529,7 +561,7 @@ def test_postgress(self, fix1_model, postgres_uri): storage_uri=postgres_uri, ) - model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4. + model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4. model.dc.flush() # Connect directly and validate data @@ -545,17 +577,27 @@ def test_postgress(self, fix1_model, postgres_uri): model_rows = cur.fetchall() assert model_rows == [(2, 12), (4, 12)] - # MODIFIED: Check agent data + # MODIFIED: Check agent data cur.execute( "SELECT wealth, age FROM agent_data WHERE step=2 ORDER BY wealth, age" ) agent_rows = cur.fetchall() expected_rows_s2 = [ - (1, 3), (2, 4), (3, 5), (3, 10), (4, 6), (4, 20), - (5, 30), (6, 40), (14, 11), (24, 22), (34, 33), (44, 44) + (1, 3), + (2, 4), + (3, 5), + (3, 10), + (4, 6), + (4, 20), + (5, 30), + (6, 40), + (14, 11), + (24, 22), + (34, 33), + (44, 44), ] - + assert sorted(agent_rows) == sorted(expected_rows_s2) cur.execute("DROP TABLE public.model_data, public.agent_data;") @@ -573,7 +615,7 @@ def test_batch_memory(self, fix2_model): len(agentset) for agentset in model.sets._agentsets ) }, - agent_reporters={ "wealth": "wealth", "age": "age" }, + agent_reporters={"wealth": "wealth", "age": "age"}, ) model.run_model_with_conditional_collect_multiple_batch(5) @@ -586,7 +628,13 @@ def test_batch_memory(self, fix2_model): agent_df = collected_data["agent"] assert agent_df.shape == (48, 7) assert set(agent_df.columns) == { - "unique_id", "agent_type", "wealth", "age", "step", "seed", "batch" + "unique_id", + "agent_type", + "wealth", + "age", + "step", + "seed", + "batch", } df_s2_b0 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 0)) @@ -622,10 +670,10 @@ def test_batch_save(self, fix2_model): model.run_model_with_conditional_collect_multiple_batch(5) model.dc.flush() - + for _ in range(20): # wait up to ~2 seconds created_files = os.listdir(tmpdir) - if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files + if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files break time.sleep(0.1) @@ -654,14 +702,14 @@ def test_batch_save(self, fix2_model): ) assert model_df_step2_batch1["step"].to_list() == [2] assert model_df_step2_batch1["total_agents"].to_list() == [12] - + model_df_step4_batch0 = pl.read_csv( os.path.join(tmpdir, "model_step4_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) assert model_df_step4_batch0["step"].to_list() == [4] assert model_df_step4_batch0["total_agents"].to_list() == [12] - + # test agent batch reset agent_df_step2_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), @@ -669,12 +717,15 @@ def test_batch_save(self, fix2_model): ) expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4] - assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted(expected_wealth_s2b0) + assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted( + expected_wealth_s2b0 + ) agent_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch1.csv"), schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64}, ) expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] - assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(expected_wealth_s2b1) - \ No newline at end of file + assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted( + expected_wealth_s2b1 + ) From 90e2c17ffd45ee55ca35de210b4c1621edd86294 Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Fri, 7 Nov 2025 02:32:36 +0530 Subject: [PATCH 8/9] Fix: Add unique_id and agent_type to agent reports --- mesa_frames/concrete/datacollector.py | 17 +++++--- tests/test_datacollector.py | 58 +++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index f954c62a..7f5e3e35 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -72,7 +72,7 @@ def __init__( self, model: Model, model_reporters: dict[str, Callable] | None = None, - agent_reporters: dict[str, str] | None = None, + agent_reporters: dict[str, str | Callable] | None = None, # <-- ALLOWS CALLABLE trigger: Callable[[Any], bool] | None = None, reset_memory: bool = True, storage: Literal[ @@ -92,7 +92,10 @@ def __init__( model_reporters : dict[str, Callable] | None Functions to collect data at the model level. agent_reporters : dict[str, str | Callable] | None - Attributes or functions to collect data at the agent level. + (MODIFIED) A dictionary mapping new column names to existing + column names (str) or callables. Callables are not currently + processed by the agent data collector but are allowed for API compatibility. + Example: {"agent_wealth": "wealth", "age_in_years": "age"} trigger : Callable[[Any], bool] | None A function(model) -> bool that determines whether to collect data. reset_memory : bool @@ -108,10 +111,14 @@ def __init__( """ if agent_reporters: for key, value in agent_reporters.items(): - if not isinstance(value, str): + if not isinstance(key, str): raise TypeError( - f"Agent reporter for '{key}' must be a string (the column name), " - f"not a {type(value)}. Callable reporters are not supported for agents." + f"Agent reporter keys must be strings (the final column name), not a {type(key)}." + ) + if not (isinstance(value, str) or callable(value)): + raise TypeError( + f"Agent reporter for '{key}' must be either a string (the source column name) " + f"or a callable (function taking an agent and returning a value), not a {type(value)}." ) super().__init__( diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 2f119df0..d5e5e5ec 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -171,13 +171,19 @@ def test__init__(self, fix1_model, postgres_uri): # model=model, model_reporters={"total_agents": "sum"} # ) + model.test_dc = DataCollector( + model=model, + agent_reporters={"wealth": lambda m: 1} + ) + assert model.test_dc is not None + with pytest.raises( - TypeError, - match="Agent reporter for 'wealth' must be a string", + beartype.roar.BeartypeCallHintParamViolation, + match="not instance of str", ): model.test_dc = DataCollector( - model=model, - agent_reporters={"wealth": lambda m: 1}, # This is now illegal + model=model, + agent_reporters={123: "wealth"} ) with pytest.raises( @@ -729,3 +735,47 @@ def test_batch_save(self, fix2_model): assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted( expected_wealth_s2b1 ) + + def test_collect_no_agentsets_list(self, fix1_model, caplog): + """Tests that the collector logs an error and exits gracefully if _agentsets is missing.""" + model = fix1_model + del model.sets._agentsets + + dc = DataCollector(model=model, agent_reporters={"wealth": "wealth"}) + dc.collect() + + assert "could not find '_agentsets'" in caplog.text + assert dc.data["agent"].shape == (0, 0) + + def test_collect_agent_set_no_df(self, fix1_model, caplog): + """Tests that the collector logs a warning and skips a set if it has no .df attribute.""" + + class NoDfSet: + def __init__(self): + self.__class__ = type("NoDfSet", (object,), {}) + + fix1_model.sets._agentsets.append(NoDfSet()) + + fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"}) + fix1_model.dc.collect() + + assert "has no 'df' attribute" in caplog.text + assert fix1_model.dc.data["agent"].shape == (12, 7) + + def test_collect_df_no_unique_id(self, fix1_model, caplog): + """Tests that the collector logs a warning and skips a set if its df has no unique_id.""" + bad_set = fix1_model.sets._agentsets[0] + bad_set._df = bad_set._df.drop("unique_id") + + fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"}) + fix1_model.dc.collect() + + assert "has no 'unique_id' column" in caplog.text + assert fix1_model.dc.data["agent"].shape == (8, 7) + + def test_collect_no_matching_reporters(self, fix1_model): + """Tests that the collector returns an empty frame if no reporters match any columns.""" + fix1_model.dc = DataCollector(fix1_model, agent_reporters={"baz": "foo", "qux": "bar"}) + fix1_model.dc.collect() + + assert fix1_model.dc.data["agent"].shape == (0, 0) \ No newline at end of file From 622af2699c0e1c509b1e975b201ec65f6cff4b23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Nov 2025 09:11:49 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/datacollector.py | 2 +- tests/test_datacollector.py | 42 ++++++++++++++------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 7f5e3e35..24904cce 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -92,7 +92,7 @@ def __init__( model_reporters : dict[str, Callable] | None Functions to collect data at the model level. agent_reporters : dict[str, str | Callable] | None - (MODIFIED) A dictionary mapping new column names to existing + (MODIFIED) A dictionary mapping new column names to existing column names (str) or callables. Callables are not currently processed by the agent data collector but are allowed for API compatibility. Example: {"agent_wealth": "wealth", "age_in_years": "age"} diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index d5e5e5ec..eeda86b4 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -172,19 +172,15 @@ def test__init__(self, fix1_model, postgres_uri): # ) model.test_dc = DataCollector( - model=model, - agent_reporters={"wealth": lambda m: 1} + model=model, agent_reporters={"wealth": lambda m: 1} ) assert model.test_dc is not None with pytest.raises( beartype.roar.BeartypeCallHintParamViolation, - match="not instance of str", + match="not instance of str", ): - model.test_dc = DataCollector( - model=model, - agent_reporters={123: "wealth"} - ) + model.test_dc = DataCollector(model=model, agent_reporters={123: "wealth"}) with pytest.raises( ValueError, @@ -743,7 +739,7 @@ def test_collect_no_agentsets_list(self, fix1_model, caplog): dc = DataCollector(model=model, agent_reporters={"wealth": "wealth"}) dc.collect() - + assert "could not find '_agentsets'" in caplog.text assert dc.data["agent"].shape == (0, 0) @@ -753,29 +749,35 @@ def test_collect_agent_set_no_df(self, fix1_model, caplog): class NoDfSet: def __init__(self): self.__class__ = type("NoDfSet", (object,), {}) - + fix1_model.sets._agentsets.append(NoDfSet()) - - fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"}) + + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"wealth": "wealth", "age": "age"} + ) fix1_model.dc.collect() - + assert "has no 'df' attribute" in caplog.text - assert fix1_model.dc.data["agent"].shape == (12, 7) + assert fix1_model.dc.data["agent"].shape == (12, 7) def test_collect_df_no_unique_id(self, fix1_model, caplog): """Tests that the collector logs a warning and skips a set if its df has no unique_id.""" - bad_set = fix1_model.sets._agentsets[0] - bad_set._df = bad_set._df.drop("unique_id") - - fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"}) + bad_set = fix1_model.sets._agentsets[0] + bad_set._df = bad_set._df.drop("unique_id") + + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"wealth": "wealth", "age": "age"} + ) fix1_model.dc.collect() - + assert "has no 'unique_id' column" in caplog.text assert fix1_model.dc.data["agent"].shape == (8, 7) def test_collect_no_matching_reporters(self, fix1_model): """Tests that the collector returns an empty frame if no reporters match any columns.""" - fix1_model.dc = DataCollector(fix1_model, agent_reporters={"baz": "foo", "qux": "bar"}) + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"baz": "foo", "qux": "bar"} + ) fix1_model.dc.collect() - assert fix1_model.dc.data["agent"].shape == (0, 0) \ No newline at end of file + assert fix1_model.dc.data["agent"].shape == (0, 0)