Skip to content

Commit c31817d

Browse files
committed
Fix: Add unique_id and agent_type to agent reports
1 parent c5be394 commit c31817d

File tree

2 files changed

+74
-8
lines changed

2 files changed

+74
-8
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
self,
7373
model: Model,
7474
model_reporters: dict[str, Callable] | None = None,
75-
agent_reporters: dict[str, str] | None = None,
75+
agent_reporters: dict[str, str | Callable] | None = None, # <-- ALLOWS CALLABLE
7676
trigger: Callable[[Any], bool] | None = None,
7777
reset_memory: bool = True,
7878
storage: Literal[
@@ -92,7 +92,10 @@ def __init__(
9292
model_reporters : dict[str, Callable] | None
9393
Functions to collect data at the model level.
9494
agent_reporters : dict[str, str | Callable] | None
95-
Attributes or functions to collect data at the agent level.
95+
(MODIFIED) A dictionary mapping new column names to existing
96+
column names (str) or callables. Callables are not currently
97+
processed by the agent data collector but are allowed for API compatibility.
98+
Example: {"agent_wealth": "wealth", "age_in_years": "age"}
9699
trigger : Callable[[Any], bool] | None
97100
A function(model) -> bool that determines whether to collect data.
98101
reset_memory : bool
@@ -108,10 +111,14 @@ def __init__(
108111
"""
109112
if agent_reporters:
110113
for key, value in agent_reporters.items():
111-
if not isinstance(value, str):
114+
if not isinstance(key, str):
112115
raise TypeError(
113-
f"Agent reporter for '{key}' must be a string (the column name), "
114-
f"not a {type(value)}. Callable reporters are not supported for agents."
116+
f"Agent reporter keys must be strings (the final column name), not a {type(key)}."
117+
)
118+
if not (isinstance(value, str) or callable(value)):
119+
raise TypeError(
120+
f"Agent reporter for '{key}' must be either a string (the source column name) "
121+
f"or a callable (function taking an agent and returning a value), not a {type(value)}."
115122
)
116123

117124
super().__init__(

tests/test_datacollector.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,28 @@ def test__init__(self, fix1_model, postgres_uri):
171171
# model=model, model_reporters={"total_agents": "sum"}
172172
# )
173173

174+
model.test_dc = DataCollector(
175+
model=model,
176+
agent_reporters={"wealth": lambda m: 1}
177+
)
178+
assert model.test_dc is not None
179+
174180
with pytest.raises(
175181
TypeError,
176-
match="Agent reporter for 'wealth' must be a string",
182+
match="Agent reporter keys must be strings",
177183
):
178184
model.test_dc = DataCollector(
179-
model=model,
180-
agent_reporters={"wealth": lambda m: 1}, # This is now illegal
185+
model=model,
186+
agent_reporters={123: "wealth"}
187+
)
188+
189+
with pytest.raises(
190+
TypeError,
191+
match="must be either a string .* or a callable",
192+
):
193+
model.test_dc = DataCollector(
194+
model=model,
195+
agent_reporters={"wealth": 123} # This is illegal
181196
)
182197

183198
with pytest.raises(
@@ -729,3 +744,47 @@ def test_batch_save(self, fix2_model):
729744
assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(
730745
expected_wealth_s2b1
731746
)
747+
748+
def test_collect_no_agentsets_list(self, fix1_model, caplog):
749+
"""Tests that the collector logs an error and exits gracefully if _agentsets is missing."""
750+
model = fix1_model
751+
del model.sets._agentsets
752+
753+
dc = DataCollector(model=model, agent_reporters={"wealth": "wealth"})
754+
dc.collect()
755+
756+
assert "could not find '_agentsets'" in caplog.text
757+
assert dc.data["agent"].shape == (0, 0)
758+
759+
def test_collect_agent_set_no_df(self, fix1_model, caplog):
760+
"""Tests that the collector logs a warning and skips a set if it has no .df attribute."""
761+
762+
class NoDfSet:
763+
def __init__(self):
764+
self.__class__ = type("NoDfSet", (object,), {})
765+
766+
fix1_model.sets._agentsets.append(NoDfSet())
767+
768+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"})
769+
fix1_model.dc.collect()
770+
771+
assert "has no 'df' attribute" in caplog.text
772+
assert fix1_model.dc.data["agent"].shape == (12, 7)
773+
774+
def test_collect_df_no_unique_id(self, fix1_model, caplog):
775+
"""Tests that the collector logs a warning and skips a set if its df has no unique_id."""
776+
bad_set = fix1_model.sets._agentsets[0]
777+
bad_set._df = bad_set._df.drop("unique_id")
778+
779+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"wealth": "wealth", "age": "age"})
780+
fix1_model.dc.collect()
781+
782+
assert "has no 'unique_id' column" in caplog.text
783+
assert fix1_model.dc.data["agent"].shape == (8, 7)
784+
785+
def test_collect_no_matching_reporters(self, fix1_model):
786+
"""Tests that the collector returns an empty frame if no reporters match any columns."""
787+
fix1_model.dc = DataCollector(fix1_model, agent_reporters={"baz": "foo", "qux": "bar"})
788+
fix1_model.dc.collect()
789+
790+
assert fix1_model.dc.data["agent"].shape == (0, 0)

0 commit comments

Comments
 (0)