diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 2b50c76d..24904cce 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -64,6 +64,7 @@ def step(self): from collections.abc import Callable from mesa_frames import Model from psycopg2.extensions import connection +import logging class DataCollector(AbstractDataCollector): @@ -71,7 +72,7 @@ def __init__( self, model: Model, model_reporters: dict[str, Callable] | None = None, - agent_reporters: dict[str, str | Callable] | None = None, + agent_reporters: dict[str, str | Callable] | None = None, # <-- ALLOWS CALLABLE trigger: Callable[[Any], bool] | None = None, reset_memory: bool = True, storage: Literal[ @@ -91,7 +92,10 @@ def __init__( model_reporters : dict[str, Callable] | None Functions to collect data at the model level. agent_reporters : dict[str, str | Callable] | None - Attributes or functions to collect data at the agent level. + (MODIFIED) A dictionary mapping new column names to existing + column names (str) or callables. Callables are not currently + processed by the agent data collector but are allowed for API compatibility. + Example: {"agent_wealth": "wealth", "age_in_years": "age"} trigger : Callable[[Any], bool] | None A function(model) -> bool that determines whether to collect data. reset_memory : bool @@ -105,6 +109,18 @@ def __init__( max_worker : int Maximum number of worker threads used for flushing collected data asynchronously """ + if agent_reporters: + for key, value in agent_reporters.items(): + if not isinstance(key, str): + raise TypeError( + f"Agent reporter keys must be strings (the final column name), not a {type(key)}." + ) + if not (isinstance(value, str) or callable(value)): + raise TypeError( + f"Agent reporter for '{key}' must be either a string (the source column name) " + f"or a callable (function taking an agent and returning a value), not a {type(value)}." + ) + super().__init__( model=model, model_reporters=model_reporters, @@ -174,25 +190,71 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): """ Collect agent-level data using the agent_reporters. - Constructs a LazyFrame with one column per reporter and - includes `step` and `seed` metadata. Appends it to internal storage. + This method iterates through all AgentSets in the model, selects the + `unique_id` and the requested reporter columns from each AgentSet's + DataFrame, adds an `agent_type` column, and concatenates them + into a single "long" format LazyFrame. """ - agent_data_dict = {} - for col_name, reporter in self._agent_reporters.items(): - if isinstance(reporter, str): - for k, v in self._model.sets[reporter].items(): - agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v - else: - agent_data_dict[col_name] = reporter(self._model) - agent_lazy_frame = pl.LazyFrame(agent_data_dict) - agent_lazy_frame = agent_lazy_frame.with_columns( + all_agent_frames = [] + reporter_map = self._agent_reporters + + try: + agent_sets_list = self._model.sets._agentsets + except AttributeError: + logging.error( + "DataCollector could not find '_agentsets' attribute on model.sets. " + "Agent data collection will be skipped." + ) + return + + for agent_set in agent_sets_list: + if not hasattr(agent_set, "df"): + logging.warning( + f"AgentSet {agent_set.__class__.__name__} has no 'df' attribute. Skipping." + ) + continue + + agent_df = agent_set.df.lazy() + agent_type = agent_set.__class__.__name__ + available_cols = agent_df.columns + + if "unique_id" not in available_cols: + logging.warning( + f"AgentSet {agent_type} 'df' has no 'unique_id' column. Skipping." + ) + continue + + cols_to_select = [pl.col("unique_id")] + + for final_name, source_col in reporter_map.items(): + if source_col in available_cols: + ## Add the column, aliasing it if the key is different + cols_to_select.append(pl.col(source_col).alias(final_name)) + + ## Only proceed if we have more than just unique_id + if len(cols_to_select) > 1: + set_frame = agent_df.select(cols_to_select) + ## Add the agent_type column + set_frame = set_frame.with_columns( + pl.lit(agent_type).alias("agent_type") + ) + all_agent_frames.append(set_frame) + + if not all_agent_frames: + return + + ## Combine all agent set DataFrames into one + final_agent_frame = pl.concat(all_agent_frames, how="diagonal_relaxed") + + ## Add metadata and append + final_agent_frame = final_agent_frame.with_columns( [ pl.lit(current_model_step).alias("step"), pl.lit(str(self.seed)).alias("seed"), pl.lit(batch_id).alias("batch"), ] ) - self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame)) + self._frames.append(("agent", current_model_step, batch_id, final_agent_frame)) @property def data(self) -> dict[str, pl.DataFrame]: @@ -461,13 +523,20 @@ def _validate_reporter_table_columns( If any expected columns are missing from the table. """ expected_columns = set() + + ## Add columns required for the new long agent format + if table_name == "agent_data": + expected_columns.add("unique_id") + expected_columns.add("agent_type") + + ## Add all keys from the reporter dict for col_name, required_column in reporter.items(): - if isinstance(required_column, str): - for k, v in self._model.sets[required_column].items(): - expected_columns.add( - (col_name + "_" + str(k.__class__.__name__)).lower() - ) + if table_name == "agent_data": + if isinstance(required_column, str): + expected_columns.add(col_name.lower()) + ## Callables are not supported for agents else: + ## For model, all reporters are callable expected_columns.add(col_name.lower()) query = f""" @@ -484,10 +553,7 @@ def _validate_reporter_table_columns( existing_columns = {row[0] for row in result} missing_columns = expected_columns - existing_columns - required_columns = { - "step": "Integer", - "seed": "Varchar", - } + required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"} missing_required = { col: col_type diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index b7407711..eeda86b4 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -9,17 +9,22 @@ def custom_trigger(model): - return model._steps % 2 == 0 + return model.steps % 2 == 0 class ExampleAgentSet1(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) - self["age"] = pl.Series("age", [10, 20, 30, 40]) + self._df = pl.DataFrame( + { + "unique_id": [101, 102, 103, 104], + "wealth": [1, 2, 3, 4], + "age": [10, 20, 30, 40], + } + ) def add_wealth(self, amount: int) -> None: - self["wealth"] += amount + self.set("wealth", self["wealth"] + amount) def step(self) -> None: self.add_wealth(1) @@ -28,11 +33,16 @@ def step(self) -> None: class ExampleAgentSet2(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) - self["age"] = pl.Series("age", [11, 22, 33, 44]) + self._df = pl.DataFrame( + { + "unique_id": [201, 202, 203, 204], + "wealth": [10, 20, 30, 40], + "age": [11, 22, 33, 44], + } + ) def add_wealth(self, amount: int) -> None: - self["wealth"] += amount + self.set("wealth", self["wealth"] + amount) def step(self) -> None: self.add_wealth(2) @@ -41,11 +51,16 @@ def step(self) -> None: class ExampleAgentSet3(AgentSet): def __init__(self, model: Model): super().__init__(model) - self["age"] = pl.Series("age", [1, 2, 3, 4]) - self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) + self._df = pl.DataFrame( + { + "unique_id": [301, 302, 303, 304], + "age": [1, 2, 3, 4], + "wealth": [1, 2, 3, 4], + } + ) def age_agents(self, amount: int) -> None: - self["age"] += amount + self.set("age", self["age"] + amount) def step(self) -> None: self.age_agents(1) @@ -55,6 +70,7 @@ class ExampleModel(Model): def __init__(self, sets: AgentSetRegistry): super().__init__() self.sets = sets + self._steps = 0 def step(self): self.sets.do("step") @@ -78,6 +94,7 @@ class ExampleModelWithMultipleCollects(Model): def __init__(self, agents: AgentSetRegistry): super().__init__() self.sets = agents + self._seed = 0 def step(self): self.dc.conditional_collect() @@ -96,17 +113,17 @@ def postgres_uri(): @pytest.fixture def fix1_AgentSet() -> ExampleAgentSet1: - return ExampleAgentSet1(Model()) + return ExampleAgentSet1(Model(seed=1)) @pytest.fixture def fix2_AgentSet() -> ExampleAgentSet2: - return ExampleAgentSet2(Model()) + return ExampleAgentSet2(Model(seed=1)) @pytest.fixture def fix3_AgentSet() -> ExampleAgentSet3: - return ExampleAgentSet3(Model()) + return ExampleAgentSet3(Model(seed=1)) @pytest.fixture @@ -117,30 +134,54 @@ def fix_AgentSetRegistry( ) -> AgentSetRegistry: model = Model() agents = AgentSetRegistry(model) - agents.add([fix1_AgentSet, fix2_AgentSet, fix3_AgentSet]) + agents._agentsets = [fix1_AgentSet, fix2_AgentSet, fix3_AgentSet] + # Manually update model link for agent sets + fix1_AgentSet._model = model + fix2_AgentSet._model = model + fix3_AgentSet._model = model return agents @pytest.fixture def fix1_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: - return ExampleModel(fix_AgentSetRegistry) + model = ExampleModel(fix_AgentSetRegistry) + fix_AgentSetRegistry._model = model + for s in fix_AgentSetRegistry._agentsets: + s._model = model + return model @pytest.fixture def fix2_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: - return ExampleModelWithMultipleCollects(fix_AgentSetRegistry) + model = ExampleModelWithMultipleCollects(fix_AgentSetRegistry) + fix_AgentSetRegistry._model = model + for s in fix_AgentSetRegistry._agentsets: + s._model = model + return model class TestDataCollector: def test__init__(self, fix1_model, postgres_uri): model = fix1_model + # with pytest.raises( + # beartype.roar.BeartypeCallHintParamViolation, + # match="not instance of .*Callable", + # ): + # model.test_dc = DataCollector( + # model=model, model_reporters={"total_agents": "sum"} + # ) + + model.test_dc = DataCollector( + model=model, agent_reporters={"wealth": lambda m: 1} + ) + assert model.test_dc is not None + with pytest.raises( beartype.roar.BeartypeCallHintParamViolation, - match="not instance of .*Callable", + match="not instance of str", ): - model.test_dc = DataCollector( - model=model, model_reporters={"total_agents": "sum"} - ) + model.test_dc = DataCollector(model=model, agent_reporters={123: "wealth"}) + with pytest.raises( ValueError, match="Please define a storage_uri to if to be stored not in memory", @@ -164,7 +205,7 @@ def test_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -185,31 +226,40 @@ def test_collect(self, fix1_model): with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["model"]["max_wealth"] - assert collected_data["agent"].shape == (4, 7) - assert set(collected_data["agent"].columns) == { + agent_df = collected_data["agent"] + + ## 3 agent sets * 4 agents/set = 12 rows + assert agent_df.shape == (12, 7) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert collected_data["agent"]["wealth"].to_list() == [1, 2, 3, 4] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [1, 2, 3, 4] - assert collected_data["agent"]["step"].to_list() == [0, 0, 0, 0] + + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] + expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [1, 2, 3, 4] + + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age) + + type_counts = agent_df["agent_type"].value_counts(sort=True) + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0] + == 4 + ) + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0] + == 4 + ) + assert ( + type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0] + == 4 + ) + assert agent_df["step"].to_list() == [0] * 12 with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["agent"]["max_wealth"] @@ -223,7 +273,7 @@ def test_collect_step(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -242,31 +292,24 @@ def test_collect_step(self, fix1_model): assert collected_data["model"]["step"].to_list() == [5] assert collected_data["model"]["total_agents"].to_list() == [12] - assert collected_data["agent"].shape == (4, 7) - assert set(collected_data["agent"].columns) == { + agent_df = collected_data["agent"] + assert agent_df.shape == (12, 7) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert collected_data["agent"]["wealth"].to_list() == [6, 7, 8, 9] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [6, 7, 8, 9] - assert collected_data["agent"]["step"].to_list() == [5, 5, 5, 5] + + expected_wealth = [6, 7, 8, 9] + [20, 30, 40, 50] + [1, 2, 3, 4] + expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [6, 7, 8, 9] + + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age) + assert agent_df["step"].to_list() == [5] * 12 def test_conditional_collect(self, fix1_model): model = fix1_model @@ -279,7 +322,7 @@ def test_conditional_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, ) @@ -297,57 +340,35 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["step"].to_list() == [2, 4] assert collected_data["model"]["total_agents"].to_list() == [12, 12] - assert collected_data["agent"].shape == (8, 7) - assert set(collected_data["agent"].columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert set(collected_data["agent"].columns) == { + agent_df = collected_data["agent"] + + # 12 agents * 2 steps = 24 rows + assert agent_df.shape == (24, 7) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - 5, - 6, - 7, - 8, - ] - assert collected_data["agent"]["step"].to_list() == [2, 2, 2, 2, 4, 4, 4, 4] + + df_step_2 = agent_df.filter(pl.col("step") == 2) + expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] + + assert df_step_2.shape == (12, 7) + assert sorted(df_step_2["wealth"].to_list()) == sorted(expected_wealth_s2) + assert sorted(df_step_2["age"].to_list()) == sorted(expected_age_s2) + + df_step_4 = agent_df.filter(pl.col("step") == 4) + expected_wealth_s4 = [5, 6, 7, 8] + [18, 28, 38, 48] + [1, 2, 3, 4] + expected_age_s4 = [10, 20, 30, 40] + [11, 22, 33, 44] + [5, 6, 7, 8] + + assert df_step_4.shape == (12, 7) + assert sorted(df_step_4["wealth"].to_list()) == sorted(expected_wealth_s4) + assert sorted(df_step_4["age"].to_list()) == sorted(expected_age_s4) def test_flush_local_csv(self, fix1_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -361,7 +382,7 @@ def test_flush_local_csv(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="csv", @@ -398,32 +419,31 @@ def test_flush_local_csv(self, fix1_model): os.path.join(tmpdir, "agent_step2_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) + assert agent_df.shape == (12, 7) assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert agent_df["step"].to_list() == [2, 2, 2, 2] - assert agent_df["wealth"].to_list() == [3, 4, 5, 6] - assert agent_df["age_ExampleAgentSet1"].to_list() == [10, 20, 30, 40] - assert agent_df["age_ExampleAgentSet2"].to_list() == [11, 22, 33, 44] - assert agent_df["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - ] - agent_df = pl.read_csv( + expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6] + + assert agent_df["step"].to_list() == [2] * 12 + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth_s2) + assert sorted(agent_df["age"].to_list()) == sorted(expected_age_s2) + + agent_df_s4 = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert agent_df["step"].to_list() == [4, 4, 4, 4] - assert agent_df["wealth"].to_list() == [5, 6, 7, 8] + expected_wealth_s4 = [5, 6, 7, 8] + [18, 28, 38, 48] + [1, 2, 3, 4] + assert agent_df_s4["step"].to_list() == [4] * 12 + assert sorted(agent_df_s4["wealth"].to_list()) == sorted(expected_wealth_s4) def test_flush_local_parquet(self, fix1_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -437,7 +457,7 @@ def test_flush_local_parquet(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", }, storage="parquet", storage_uri=tmpdir, @@ -466,8 +486,20 @@ def test_flush_local_parquet(self, fix1_model): agent_df = pl.read_parquet( os.path.join(tmpdir, "agent_step0_batch0.parquet") ) - assert agent_df["step"].to_list() == [0, 0, 0, 0] - assert agent_df["wealth"].to_list() == [1, 2, 3, 4] + # 12 rows. 6 cols: unique_id, agent_type, wealth, step, seed, batch + assert agent_df.shape == (12, 6) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", + "wealth", + "step", + "seed", + "batch", + } + + expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4] + assert agent_df["step"].to_list() == [0] * 12 + assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth) @pytest.mark.skipif( os.getenv("SKIP_PG_TESTS") == "true", @@ -479,9 +511,17 @@ def test_postgress(self, fix1_model, postgres_uri): # Connect directly and validate data import psycopg2 - conn = psycopg2.connect(postgres_uri) + try: + conn = psycopg2.connect(postgres_uri) + except psycopg2.OperationalError as e: + pytest.skip(f"Could not connect to PostgreSQL: {e}") + cur = conn.cursor() + ## Cleaning up tables first + cur.execute("DROP TABLE IF EXISTS public.model_data, public.agent_data;") + conn.commit() + cur.execute(""" CREATE TABLE public.model_data ( step INTEGER, @@ -491,15 +531,16 @@ def test_postgress(self, fix1_model, postgres_uri): ) """) + ## MODIFIED: CREATE TABLE for long format cur.execute(""" CREATE TABLE public.agent_data ( step INTEGER, seed VARCHAR, batch INTEGER, - age_ExampleAgentSet1 INTEGER, - age_ExampleAgentSet2 INTEGER, - age_ExampleAgentSet3 INTEGER, - wealth INTEGER + unique_id BIGINT, + agent_type VARCHAR, + wealth INTEGER, + age INTEGER ) """) conn.commit() @@ -512,8 +553,9 @@ def test_postgress(self, fix1_model, postgres_uri): len(agentset) for agentset in model.sets._agentsets ) }, + ## MODIFIED : long format agent reporters agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="postgresql", @@ -521,7 +563,7 @@ def test_postgress(self, fix1_model, postgres_uri): storage_uri=postgres_uri, ) - model.run_model_with_conditional_collect(4) + model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4. model.dc.flush() # Connect directly and validate data @@ -537,17 +579,31 @@ def test_postgress(self, fix1_model, postgres_uri): model_rows = cur.fetchall() assert model_rows == [(2, 12), (4, 12)] + # MODIFIED: Check agent data cur.execute( - "SELECT step, batch, wealth,age_ExampleAgentSet1, age_ExampleAgentSet2, age_ExampleAgentSet3 FROM agent_data WHERE step=2 ORDER BY wealth" + "SELECT wealth, age FROM agent_data WHERE step=2 ORDER BY wealth, age" ) agent_rows = cur.fetchall() - assert agent_rows == [ - (2, 0, 3, 10, 11, 3), - (2, 0, 4, 20, 22, 4), - (2, 0, 5, 30, 33, 5), - (2, 0, 6, 40, 44, 6), + + expected_rows_s2 = [ + (1, 3), + (2, 4), + (3, 5), + (3, 10), + (4, 6), + (4, 20), + (5, 30), + (6, 40), + (14, 11), + (24, 22), + (34, 33), + (44, 44), ] + assert sorted(agent_rows) == sorted(expected_rows_s2) + + cur.execute("DROP TABLE public.model_data, public.agent_data;") + conn.commit() cur.close() conn.close() @@ -561,139 +617,39 @@ def test_batch_memory(self, fix2_model): len(agentset) for agentset in model.sets._agentsets ) }, - agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], - "age": "age", - }, + agent_reporters={"wealth": "wealth", "age": "age"}, ) model.run_model_with_conditional_collect_multiple_batch(5) collected_data = model.dc.data + assert collected_data["model"].shape == (4, 4) - assert set(collected_data["model"].columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert collected_data["model"]["step"].to_list() == [2, 2, 4, 4] assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1] - assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12] - assert collected_data["agent"].shape == (16, 7) - assert set(collected_data["agent"].columns) == { + agent_df = collected_data["agent"] + assert agent_df.shape == (48, 7) + assert set(agent_df.columns) == { + "unique_id", + "agent_type", "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", + "age", "step", "seed", "batch", } - assert set(collected_data["agent"].columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } + df_s2_b0 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 0)) + expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4] + assert sorted(df_s2_b0["wealth"].to_list()) == sorted(expected_wealth_s2b0) - assert collected_data["agent"]["step"].to_list() == [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ] - assert collected_data["agent"]["wealth"].to_list() == [ - 2, - 3, - 4, - 5, - 3, - 4, - 5, - 6, - 4, - 5, - 6, - 7, - 5, - 6, - 7, - 8, - ] - assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - 10, - 20, - 30, - 40, - ] - assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - 11, - 22, - 33, - 44, - ] - assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [ - 2, - 3, - 4, - 5, - 3, - 4, - 5, - 6, - 4, - 5, - 6, - 7, - 5, - 6, - 7, - 8, - ] + df_s2_b1 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 1)) + expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + assert sorted(df_s2_b1["wealth"].to_list()) == sorted(expected_wealth_s2b1) - with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): - collected_data["agent"]["max_wealth"] + df_s4_b0 = agent_df.filter((pl.col("step") == 4) & (pl.col("batch") == 0)) + expected_wealth_s4b0 = [4, 5, 6, 7] + [16, 26, 36, 46] + [1, 2, 3, 4] + assert sorted(df_s4_b0["wealth"].to_list()) == sorted(expected_wealth_s4b0) def test_batch_save(self, fix2_model): with tempfile.TemporaryDirectory() as tmpdir: @@ -707,7 +663,7 @@ def test_batch_save(self, fix2_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": "wealth", "age": "age", }, storage="csv", @@ -716,9 +672,10 @@ def test_batch_save(self, fix2_model): model.run_model_with_conditional_collect_multiple_batch(5) model.dc.flush() + for _ in range(20): # wait up to ~2 seconds created_files = os.listdir(tmpdir) - if len(created_files) >= 4: + if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files break time.sleep(0.1) @@ -738,140 +695,89 @@ def test_batch_save(self, fix2_model): os.path.join(tmpdir, "model_step2_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step2_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert model_df_step2_batch0["step"].to_list() == [2] assert model_df_step2_batch0["total_agents"].to_list() == [12] - model_df_step2_batch0 = pl.read_csv( + model_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "model_step2_batch1.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step2_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } - assert model_df_step2_batch0["step"].to_list() == [2] - assert model_df_step2_batch0["total_agents"].to_list() == [12] + assert model_df_step2_batch1["step"].to_list() == [2] + assert model_df_step2_batch1["total_agents"].to_list() == [12] model_df_step4_batch0 = pl.read_csv( os.path.join(tmpdir, "model_step4_batch0.csv"), schema_overrides={"seed": pl.Utf8}, ) - assert set(model_df_step4_batch0.columns) == { - "step", - "seed", - "batch", - "total_agents", - } assert model_df_step4_batch0["step"].to_list() == [4] assert model_df_step4_batch0["total_agents"].to_list() == [12] # test agent batch reset agent_df_step2_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64}, + ) + + expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4] + assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted( + expected_wealth_s2b0 ) - assert set(agent_df_step2_batch0.columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step2_batch0["step"].to_list() == [2, 2, 2, 2] - assert agent_df_step2_batch0["wealth"].to_list() == [2, 3, 4, 5] - assert agent_df_step2_batch0["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step2_batch0["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step2_batch0["age_ExampleAgentSet3"].to_list() == [ - 2, - 3, - 4, - 5, - ] agent_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch1.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64}, ) - assert set(agent_df_step2_batch1.columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step2_batch1["step"].to_list() == [2, 2, 2, 2] - assert agent_df_step2_batch1["wealth"].to_list() == [3, 4, 5, 6] - assert agent_df_step2_batch1["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step2_batch1["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step2_batch1["age_ExampleAgentSet3"].to_list() == [ - 3, - 4, - 5, - 6, - ] - - agent_df_step4_batch0 = pl.read_csv( - os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4] + assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted( + expected_wealth_s2b1 ) - assert set(agent_df_step4_batch0.columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } - assert agent_df_step4_batch0["step"].to_list() == [4, 4, 4, 4] - assert agent_df_step4_batch0["wealth"].to_list() == [4, 5, 6, 7] - assert agent_df_step4_batch0["age_ExampleAgentSet1"].to_list() == [ - 10, - 20, - 30, - 40, - ] - assert agent_df_step4_batch0["age_ExampleAgentSet2"].to_list() == [ - 11, - 22, - 33, - 44, - ] - assert agent_df_step4_batch0["age_ExampleAgentSet3"].to_list() == [ - 4, - 5, - 6, - 7, - ] + + def test_collect_no_agentsets_list(self, fix1_model, caplog): + """Tests that the collector logs an error and exits gracefully if _agentsets is missing.""" + model = fix1_model + del model.sets._agentsets + + dc = DataCollector(model=model, agent_reporters={"wealth": "wealth"}) + dc.collect() + + assert "could not find '_agentsets'" in caplog.text + assert dc.data["agent"].shape == (0, 0) + + def test_collect_agent_set_no_df(self, fix1_model, caplog): + """Tests that the collector logs a warning and skips a set if it has no .df attribute.""" + + class NoDfSet: + def __init__(self): + self.__class__ = type("NoDfSet", (object,), {}) + + fix1_model.sets._agentsets.append(NoDfSet()) + + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"wealth": "wealth", "age": "age"} + ) + fix1_model.dc.collect() + + assert "has no 'df' attribute" in caplog.text + assert fix1_model.dc.data["agent"].shape == (12, 7) + + def test_collect_df_no_unique_id(self, fix1_model, caplog): + """Tests that the collector logs a warning and skips a set if its df has no unique_id.""" + bad_set = fix1_model.sets._agentsets[0] + bad_set._df = bad_set._df.drop("unique_id") + + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"wealth": "wealth", "age": "age"} + ) + fix1_model.dc.collect() + + assert "has no 'unique_id' column" in caplog.text + assert fix1_model.dc.data["agent"].shape == (8, 7) + + def test_collect_no_matching_reporters(self, fix1_model): + """Tests that the collector returns an empty frame if no reporters match any columns.""" + fix1_model.dc = DataCollector( + fix1_model, agent_reporters={"baz": "foo", "qux": "bar"} + ) + fix1_model.dc.collect() + + assert fix1_model.dc.data["agent"].shape == (0, 0)