Commit 5c1525f
Fix: Add unique_id and agent_type to agent reports
1 parent ef6efb3
2 files changed: +281 additions, -500 deletions

mesa_frames/concrete/datacollector.py
Lines changed: 83 additions & 93 deletions
@@ -54,7 +54,6 @@ def step(self):
         self.dc.flush()
 """

-from unittest import result
 import polars as pl
 import boto3
 from urllib.parse import urlparse
@@ -65,14 +64,15 @@ def step(self):
 from collections.abc import Callable
 from mesa_frames import Model
 from psycopg2.extensions import connection
+import logging


 class DataCollector(AbstractDataCollector):
     def __init__(
         self,
         model: Model,
         model_reporters: dict[str, Callable] | None = None,
-        agent_reporters: dict[str, str | Callable] | None = None,
+        agent_reporters: dict[str, str] | None = None,
         trigger: Callable[[Any], bool] | None = None,
         reset_memory: bool = True,
         storage: Literal[
@@ -106,6 +106,14 @@ def __init__(
         max_worker : int
             Maximum number of worker threads used for flushing collected data asynchronously
         """
+        if agent_reporters:
+            for key, value in agent_reporters.items():
+                if not isinstance(value, str):
+                    raise TypeError(
+                        f"Agent reporter for '{key}' must be a string (the column name), "
+                        f"not a {type(value)}. Callable reporters are not supported for agents."
+                    )
+
         super().__init__(
             model=model,
             model_reporters=model_reporters,
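
The constructor change above means agent reporters are declared purely as names of columns that exist on the AgentSets' DataFrames. A minimal sketch of the new contract, assuming an already-constructed mesa-frames model instance called `model`; the "wealth" and "age" column names are illustrative only:

    from mesa_frames.concrete.datacollector import DataCollector

    ## Accepted: every value names a column present on the AgentSets' DataFrames
    dc = DataCollector(model=model, agent_reporters={"wealth": "wealth", "age": "age"})

    ## Rejected at construction time: callable agent reporters now raise TypeError
    DataCollector(model=model, agent_reporters={"wealth": lambda m: 0})
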
@@ -173,99 +181,73 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int):

     def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
         """
-        Collect agent-level data using the agent_reporters, including unique agent IDs.
+        Collect agent-level data using the agent_reporters.

-        Constructs a LazyFrame with one column per reporter and includes:
-        - agent_id : unique identifier for each agent
-        - step, seed, and batch columns for context
-        - Columns for all requested agent reporters
+        This method iterates through all AgentSets in the model, selects the
+        `unique_id` and the requested reporter columns from each AgentSet's
+        DataFrame, adds an `agent_type` column, and concatenates them
+        into a single "long" format LazyFrame.
         """
         all_agent_frames = []
+        reporter_map = self._agent_reporters
+
+        try:
+            agent_sets_list = self._model.sets._agentsets
+        except AttributeError:
+            logging.error(
+                "DataCollector could not find '_agentsets' attribute on model.sets. "
+                "Agent data collection will be skipped."
+            )
+            return

-        for col_name, reporter in self._agent_reporters.items():
-            if isinstance(reporter, str):
-                agent_set = self._model.sets[reporter]
+        for agent_set in agent_sets_list:
+            if not hasattr(agent_set, "df"):
+                logging.warning(
+                    f"AgentSet {agent_set.__class__.__name__} has no 'df' attribute. Skipping."
+                )
+                continue

-                if hasattr(agent_set, "df"):
-                    df = agent_set.df.select(["id", col_name]).rename(
-                        {"id": "agent_id"}
-                    )
-                elif hasattr(agent_set, "to_polars"):
-                    df = (
-                        agent_set.to_polars()
-                        .select(["id", col_name])
-                        .rename({"id": "agent_id"})
-                    )
-                else:
-                    records = []
-                    for agent in agent_set.values():
-                        agent_id = getattr(
-                            agent, "unique_id", getattr(agent, "id", None)
-                        )
-                        records.append(
-                            {
-                                "agent_id": agent_id,
-                                col_name: getattr(agent, col_name, None),
-                            }
-                        )
-                    df = pl.DataFrame(records)
+            agent_df = agent_set.df.lazy()
+            agent_type = agent_set.__class__.__name__
+            available_cols = agent_df.columns

-            else:
-                result = reporter(self._model)
-
-                ## Case 1: already a DataFrame
-                if isinstance(result, pl.DataFrame):
-                    df = result
-                ## Case 2: dict or list -> convert
-                elif isinstance(result, dict):
-                    df = pl.DataFrame([result])
-                elif isinstance(result, list):
-                    df = pl.DataFrame(result)
-                else:
-                    ## Case 3: scalar or callable reporter
-                    if hasattr(self._model, "agents"):
-                        records = []
-                        for agent in self._model.agents:
-                            agent_id = getattr(
-                                agent, "unique_id", getattr(agent, "id", None)
-                            )
-                            value = getattr(
-                                agent,
-                                col_name,
-                                result if not callable(result) else None,
-                            )
-                            records.append({"agent_id": agent_id, col_name: value})
-                        df = pl.DataFrame(records)
-                    else:
-                        df = pl.DataFrame([{col_name: result}])
-
-            ## Ensure agent_id exists
-            if "agent_id" not in df.columns:
-                df = df.with_columns(pl.lit(None).alias("agent_id"))
-
-            ## Add meta columns
-            df = df.with_columns(
-                [
-                    pl.lit(current_model_step).alias("step"),
-                    pl.lit(str(self.seed)).alias("seed"),
-                    pl.lit(batch_id).alias("batch"),
-                ]
-            )
-            all_agent_frames.append(df)
-
-        if all_agent_frames:
-            merged_df = all_agent_frames[0]
-            for next_df in all_agent_frames[1:]:
-                if "agent_id" not in next_df.columns:
-                    continue
-                merged_df = merged_df.join(
-                    next_df, on=["agent_id", "step", "seed", "batch"], how="outer"
+            if "unique_id" not in available_cols:
+                logging.warning(
+                    f"AgentSet {agent_type} 'df' has no 'unique_id' column. Skipping."
                 )
+                continue
+
+            cols_to_select = [pl.col("unique_id")]
+
+            for final_name, source_col in reporter_map.items():
+                if source_col in available_cols:
+                    ## Add the column, aliasing it if the key is different
+                    cols_to_select.append(pl.col(source_col).alias(final_name))
+
+            ## Only proceed if we have more than just unique_id
+            if len(cols_to_select) > 1:
+                set_frame = agent_df.select(cols_to_select)
+                ## Add the agent_type column
+                set_frame = set_frame.with_columns(
+                    pl.lit(agent_type).alias("agent_type")
+                )
+                all_agent_frames.append(set_frame)

-            agent_lazy_frame = merged_df.lazy()
-            self._frames.append(
-                ("agent", current_model_step, batch_id, agent_lazy_frame)
-            )
+        if not all_agent_frames:
+            return
+
+        ## Combine all agent set DataFrames into one
+        final_agent_frame = pl.concat(all_agent_frames, how="diagonal_relaxed")
+
+        ## Add metadata and append
+        final_agent_frame = final_agent_frame.with_columns(
+            [
+                pl.lit(current_model_step).alias("step"),
+                pl.lit(str(self.seed)).alias("seed"),
+                pl.lit(batch_id).alias("batch"),
+            ]
+        )
+        self._frames.append(("agent", current_model_step, batch_id, final_agent_frame))

     @property
     def data(self) -> dict[str, pl.DataFrame]:
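
The rewritten method above stacks AgentSets row-wise instead of joining reporter columns on agent_id. A minimal sketch of the resulting "long" layout in plain polars; the Wolves/Sheep set names and the energy/wool columns are made up for illustration and are not part of this commit:

    import polars as pl

    ## Frames shaped like the per-AgentSet selections built above
    wolves = pl.LazyFrame({"unique_id": [1, 2], "energy": [5, 7]}).with_columns(
        pl.lit("Wolves").alias("agent_type")
    )
    sheep = pl.LazyFrame({"unique_id": [3], "wool": [2.5]}).with_columns(
        pl.lit("Sheep").alias("agent_type")
    )

    ## diagonal_relaxed unions the column sets and fills the gaps with nulls
    combined = pl.concat([wolves, sheep], how="diagonal_relaxed")
    print(combined.collect())
    ## 3 rows x 4 columns: unique_id, energy, agent_type, wool
    ## 'wool' is null for the Wolves rows and 'energy' is null for the Sheep row

Because missing reporter columns simply come out as nulls, AgentSets with different schemas no longer need the outer joins used in the removed code.
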
@@ -534,13 +516,20 @@ def _validate_reporter_table_columns(
             If any expected columns are missing from the table.
         """
         expected_columns = set()
+
+        ## Add columns required for the new long agent format
+        if table_name == "agent_data":
+            expected_columns.add("unique_id")
+            expected_columns.add("agent_type")
+
+        ## Add all keys from the reporter dict
         for col_name, required_column in reporter.items():
-            if isinstance(required_column, str):
-                for k, v in self._model.sets[required_column].items():
-                    expected_columns.add(
-                        (col_name + "_" + str(k.__class__.__name__)).lower()
-                    )
+            if table_name == "agent_data":
+                if isinstance(required_column, str):
+                    expected_columns.add(col_name.lower())
+                ## Callables are not supported for agents
             else:
+                ## For model, all reporters are callable
                 expected_columns.add(col_name.lower())

         query = f"""
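
For an agent_data table, the validator now derives its expected column set from the long format plus the lowercased reporter keys. A small illustration with a hypothetical reporter mapping (step, seed, and batch are checked separately as required metadata columns, as the next hunk shows):

    ## Hypothetical agent reporters: output name -> source column
    reporter = {"Wealth": "wealth", "Age": "age"}

    expected_columns = {"unique_id", "agent_type"}
    expected_columns.update(
        name.lower() for name, source in reporter.items() if isinstance(source, str)
    )
    print(sorted(expected_columns))  # ['age', 'agent_type', 'unique_id', 'wealth']
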
@@ -560,6 +549,7 @@ def _validate_reporter_table_columns(
         required_columns = {
             "step": "Integer",
             "seed": "Varchar",
+            "batch": "Integer"
         }

         missing_required = {
@@ -606,4 +596,4 @@ def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple
         """
         with conn.cursor() as cur:
             cur.execute(query)
-            return cur.fetchall()
+            return cur.fetchall()
