Skip to content

Commit ef6efb3

Browse files
committed
returns Polar Dataframe.
2 parents 6806158 + f1d2e1c commit ef6efb3

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,27 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
187187
agent_set = self._model.sets[reporter]
188188

189189
if hasattr(agent_set, "df"):
190-
df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"})
190+
df = agent_set.df.select(["id", col_name]).rename(
191+
{"id": "agent_id"}
192+
)
191193
elif hasattr(agent_set, "to_polars"):
192-
df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"})
194+
df = (
195+
agent_set.to_polars()
196+
.select(["id", col_name])
197+
.rename({"id": "agent_id"})
198+
)
193199
else:
194200
records = []
195201
for agent in agent_set.values():
196-
agent_id = getattr(agent, "unique_id", getattr(agent, "id", None))
197-
records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)})
202+
agent_id = getattr(
203+
agent, "unique_id", getattr(agent, "id", None)
204+
)
205+
records.append(
206+
{
207+
"agent_id": agent_id,
208+
col_name: getattr(agent, col_name, None),
209+
}
210+
)
198211
df = pl.DataFrame(records)
199212

200213
else:
@@ -213,8 +226,14 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
213226
if hasattr(self._model, "agents"):
214227
records = []
215228
for agent in self._model.agents:
216-
agent_id = getattr(agent, "unique_id", getattr(agent, "id", None))
217-
value = getattr(agent, col_name, result if not callable(result) else None)
229+
agent_id = getattr(
230+
agent, "unique_id", getattr(agent, "id", None)
231+
)
232+
value = getattr(
233+
agent,
234+
col_name,
235+
result if not callable(result) else None,
236+
)
218237
records.append({"agent_id": agent_id, col_name: value})
219238
df = pl.DataFrame(records)
220239
else:
@@ -225,23 +244,28 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
225244
df = df.with_columns(pl.lit(None).alias("agent_id"))
226245

227246
## Add meta columns
228-
df = df.with_columns([
229-
pl.lit(current_model_step).alias("step"),
230-
pl.lit(str(self.seed)).alias("seed"),
231-
pl.lit(batch_id).alias("batch"),
232-
])
247+
df = df.with_columns(
248+
[
249+
pl.lit(current_model_step).alias("step"),
250+
pl.lit(str(self.seed)).alias("seed"),
251+
pl.lit(batch_id).alias("batch"),
252+
]
253+
)
233254
all_agent_frames.append(df)
234255

235256
if all_agent_frames:
236257
merged_df = all_agent_frames[0]
237258
for next_df in all_agent_frames[1:]:
238259
if "agent_id" not in next_df.columns:
239260
continue
240-
merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer")
261+
merged_df = merged_df.join(
262+
next_df, on=["agent_id", "step", "seed", "batch"], how="outer"
263+
)
241264

242265
agent_lazy_frame = merged_df.lazy()
243-
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
244-
266+
self._frames.append(
267+
("agent", current_model_step, batch_id, agent_lazy_frame)
268+
)
245269

246270
@property
247271
def data(self) -> dict[str, pl.DataFrame]:

0 commit comments

Comments
 (0)