@@ -54,7 +54,6 @@ def step(self):
5454 self.dc.flush()
5555"""
5656
57- from unittest import result
5857import polars as pl
5958import boto3
6059from urllib .parse import urlparse
@@ -65,14 +64,15 @@ def step(self):
6564from collections .abc import Callable
6665from mesa_frames import Model
6766from psycopg2 .extensions import connection
67+ import logging
6868
6969
7070class DataCollector (AbstractDataCollector ):
7171 def __init__ (
7272 self ,
7373 model : Model ,
7474 model_reporters : dict [str , Callable ] | None = None ,
75- agent_reporters : dict [str , str | Callable ] | None = None ,
75+ agent_reporters : dict [str , str ] | None = None ,
7676 trigger : Callable [[Any ], bool ] | None = None ,
7777 reset_memory : bool = True ,
7878 storage : Literal [
@@ -106,6 +106,14 @@ def __init__(
106106 max_worker : int
107107 Maximum number of worker threads used for flushing collected data asynchronously
108108 """
109+ if agent_reporters :
110+ for key , value in agent_reporters .items ():
111+ if not isinstance (value , str ):
112+ raise TypeError (
113+ f"Agent reporter for '{ key } ' must be a string (the column name), "
114+ f"not a { type (value )} . Callable reporters are not supported for agents."
115+ )
116+
109117 super ().__init__ (
110118 model = model ,
111119 model_reporters = model_reporters ,
@@ -173,99 +181,73 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int):
173181
def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
    """
    Collect agent-level data using the agent_reporters.

    Iterates over the AgentSets found in ``self._model.sets._agentsets``. For
    each set, selects ``unique_id`` plus every requested reporter column that
    is present in that set's DataFrame, tags the rows with an ``agent_type``
    column (the AgentSet's class name) and concatenates the per-set frames
    into a single "long" format LazyFrame. ``step``, ``seed`` and ``batch``
    columns are then added and the frame is appended to ``self._frames``.

    Parameters
    ----------
    current_model_step : int
        Value written to the ``step`` column and stored alongside the frame.
    batch_id : int
        Value written to the ``batch`` column and stored alongside the frame.

    Notes
    -----
    Nothing is appended when no reporters are configured, when
    ``_agentsets`` cannot be found on ``model.sets``, or when no AgentSet
    exposes any of the requested columns. A reporter column missing from a
    given AgentSet is simply left out of that set's frame; the
    ``"diagonal_relaxed"`` concat takes the union of the columns.
    """
    reporter_map = self._agent_reporters
    if not reporter_map:
        # No agent reporters configured: nothing to select, nothing to store.
        return

    logger = logging.getLogger(__name__)

    try:
        agent_sets_list = self._model.sets._agentsets
    except AttributeError:
        logger.error(
            "DataCollector could not find '_agentsets' attribute on model.sets. "
            "Agent data collection will be skipped."
        )
        return

    all_agent_frames = []
    for agent_set in agent_sets_list:
        agent_type = agent_set.__class__.__name__

        if not hasattr(agent_set, "df"):
            logger.warning(
                "AgentSet %s has no 'df' attribute. Skipping.", agent_type
            )
            continue

        agent_df = agent_set.df
        # Read the column names before switching to lazy mode: asking a
        # LazyFrame for its columns forces a schema resolution.
        available_cols = set(agent_df.columns)

        if "unique_id" not in available_cols:
            logger.warning(
                "AgentSet %s 'df' has no 'unique_id' column. Skipping.", agent_type
            )
            continue

        # One expression per reporter present in this set, renamed to the
        # reporter key when the key differs from the source column.
        reporter_exprs = [
            pl.col(source_col).alias(final_name)
            for final_name, source_col in reporter_map.items()
            if source_col in available_cols
        ]
        if not reporter_exprs:
            # Only unique_id would be selected; this set contributes no data.
            continue

        all_agent_frames.append(
            agent_df.lazy()
            .select([pl.col("unique_id"), *reporter_exprs])
            .with_columns(pl.lit(agent_type).alias("agent_type"))
        )

    if not all_agent_frames:
        return

    # Combine all agent set frames into one, then add the run metadata.
    final_agent_frame = pl.concat(
        all_agent_frames, how="diagonal_relaxed"
    ).with_columns(
        [
            pl.lit(current_model_step).alias("step"),
            pl.lit(str(self.seed)).alias("seed"),
            pl.lit(batch_id).alias("batch"),
        ]
    )
    self._frames.append(("agent", current_model_step, batch_id, final_agent_frame))
269251
270252 @property
271253 def data (self ) -> dict [str , pl .DataFrame ]:
@@ -534,13 +516,20 @@ def _validate_reporter_table_columns(
534516 If any expected columns are missing from the table.
535517 """
536518 expected_columns = set ()
519+
520+ ## Add columns required for the new long agent format
521+ if table_name == "agent_data" :
522+ expected_columns .add ("unique_id" )
523+ expected_columns .add ("agent_type" )
524+
525+ ## Add all keys from the reporter dict
537526 for col_name , required_column in reporter .items ():
538- if isinstance (required_column , str ):
539- for k , v in self ._model .sets [required_column ].items ():
540- expected_columns .add (
541- (col_name + "_" + str (k .__class__ .__name__ )).lower ()
542- )
527+ if table_name == "agent_data" :
528+ if isinstance (required_column , str ):
529+ expected_columns .add (col_name .lower ())
530+ ## Callables are not supported for agents
543531 else :
532+ ## For model, all reporters are callable
544533 expected_columns .add (col_name .lower ())
545534
546535 query = f"""
@@ -560,6 +549,7 @@ def _validate_reporter_table_columns(
560549 required_columns = {
561550 "step" : "Integer" ,
562551 "seed" : "Varchar" ,
552+ "batch" : "Integer"
563553 }
564554
565555 missing_required = {
@@ -606,4 +596,4 @@ def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple
606596 """
607597 with conn .cursor () as cur :
608598 cur .execute (query )
609- return cur .fetchall ()
599+ return cur .fetchall ()
0 commit comments