Skip to content

Commit 90e2c17

Browse files
committed
Fix: Add unique_id and agent_type to agent reports
1 parent c5be394 commit 90e2c17

File tree

2 files changed

+66
-9
lines changed

2 files changed

+66
-9
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
self,
7373
model: Model,
7474
model_reporters: dict[str, Callable] | None = None,
75-
agent_reporters: dict[str, str] | None = None,
75+
agent_reporters: dict[str, str | Callable] | None = None, # <-- ALLOWS CALLABLE
7676
trigger: Callable[[Any], bool] | None = None,
7777
reset_memory: bool = True,
7878
storage: Literal[
@@ -92,7 +92,10 @@ def __init__(
9292
model_reporters : dict[str, Callable] | None
9393
Functions to collect data at the model level.
9494
agent_reporters : dict[str, str | Callable] | None
95-
Attributes or functions to collect data at the agent level.
95+
(MODIFIED) A dictionary mapping new column names to existing
96+
column names (str) or callables. Callables are not currently
97+
processed by the agent data collector but are allowed for API compatibility.
98+
Example: {"agent_wealth": "wealth", "age_in_years": "age"}
9699
trigger : Callable[[Any], bool] | None
97100
A function(model) -> bool that determines whether to collect data.
98101
reset_memory : bool
@@ -108,10 +111,14 @@ def __init__(
108111
"""
109112
if agent_reporters:
110113
for key, value in agent_reporters.items():
111-
if not isinstance(value, str):
114+
if not isinstance(key, str):
112115
raise TypeError(
113-
f"Agent reporter for '{key}' must be a string (the column name), "
114-
f"not a {type(value)}. Callable reporters are not supported for agents."
116+
f"Agent reporter keys must be strings (the final column name), not a {type(key)}."
117+
)
118+
if not (isinstance(value, str) or callable(value)):
119+
raise TypeError(
120+
f"Agent reporter for '{key}' must be either a string (the source column name) "
121+
f"or a callable (function taking an agent and returning a value), not a {type(value)}."
115122
)
116123

117124
super().__init__(

tests/test_datacollector.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,19 @@ def test__init__(self, fix1_model, postgres_uri):
171171
# model=model, model_reporters={"total_agents": "sum"}
172172
# )
173173

174+
model.test_dc = DataCollector(
175+
model=model,
176+
agent_reporters={"wealth": lambda m: 1}
177+
)
178+
assert model.test_dc is not None
179+
174180
with pytest.raises(
175-
TypeError,
176-
match="Agent reporter for 'wealth' must be a string",
181+
beartype.roar.BeartypeCallHintParamViolation,
182+
match="not instance of str",
177183
):
178184
model.test_dc = DataCollector(
179-
model=model,
180-
agent_reporters={"wealth": lambda m: 1}, # This is now illegal
185+
model=model,
186+
agent_reporters={123: "wealth"}
181187
)
182188

183189
with pytest.raises(
@@ -729,3 +735,47 @@ def test_batch_save(self, fix2_model):
729735
assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(
730736
expected_wealth_s2b1
731737
)
738+
739+
def test_collect_no_agentsets_list(self, fix1_model, caplog):
740+
"""Tests that the collector logs an error and exits gracefully if _agentsets is missing."""
741+
model = fix1_model
742+
del model.sets._agentsets
743+
744+
dc = DataCollector(model=model, agent_reporters={"wealth": "wealth"})
745+
dc.collect()
746+
747+
assert "could not find '_agentsets'" in caplog.text
748+
assert dc.data["agent"].shape == (0, 0)
749+
750+
def test_collect_agent_set_no_df(self, fix1_model, caplog):
751+
"""Tests that the collector logs a warning and skips a set if it has no .df attribute."""
752+
753+
class NoDfSet:
754+
def __init__(self):
755+
self.__class__ = type("NoDfSet", (object,), {})
756+
757+
fix1_model.sets._agentsets.append(NoDfSet())
758+
759+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"})
760+
fix1_model.dc.collect()
761+
762+
assert "has no 'df' attribute" in caplog.text
763+
assert fix1_model.dc.data["agent"].shape == (12, 7)
764+
765+
def test_collect_df_no_unique_id(self, fix1_model, caplog):
766+
"""Tests that the collector logs a warning and skips a set if its df has no unique_id."""
767+
bad_set = fix1_model.sets._agentsets[0]
768+
bad_set._df = bad_set._df.drop("unique_id")
769+
770+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"})
771+
fix1_model.dc.collect()
772+
773+
assert "has no 'unique_id' column" in caplog.text
774+
assert fix1_model.dc.data["agent"].shape == (8, 7)
775+
776+
def test_collect_no_matching_reporters(self, fix1_model):
777+
"""Tests that the collector returns an empty frame if no reporters match any columns."""
778+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"baz": "foo", "qux": "bar"})
779+
fix1_model.dc.collect()
780+
781+
assert fix1_model.dc.data["agent"].shape == (0, 0)

0 commit comments

Comments
 (0)