Skip to content

Commit 67c8d1e

Browse files
authored
refactor: optimize agent presence checks by using implode() for unique_id comparisons (#178)
1 parent dccc834 commit 67c8d1e

File tree

5 files changed

+34
-27
lines changed

5 files changed

+34
-27
lines changed

examples/sugarscape_ig/ss_polars/agents.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def __init__(
3737
def eat(self):
3838
# Only consider cells currently occupied by agents of this set
3939
cells = self.space.cells.filter(pl.col("agent_id").is_not_null())
40-
mask_in_set = cells["agent_id"].is_in(self.index)
40+
mask_in_set = cells["agent_id"].is_in(self.index.implode())
4141
if mask_in_set.any():
4242
cells = cells.filter(mask_in_set)
4343
ids = cells["agent_id"]
@@ -201,7 +201,7 @@ def get_best_moves(self, neighborhood: pl.DataFrame):
201201
)
202202
if len(best_moves) > 0:
203203
condition = condition | pl.col("blocking_agent_id").is_in(
204-
best_moves["agent_id_center"]
204+
best_moves["agent_id_center"].implode()
205205
)
206206

207207
condition = condition & (pl.col("priority") == 1)
@@ -212,7 +212,9 @@ def get_best_moves(self, neighborhood: pl.DataFrame):
212212

213213
# Remove agents that have already moved
214214
neighborhood = neighborhood.filter(
215-
~pl.col("agent_id_center").is_in(best_moves["agent_id_center"])
215+
~pl.col("agent_id_center").is_in(
216+
best_moves["agent_id_center"].implode()
217+
)
216218
)
217219

218220
# Remove cells that have been already selected

mesa_frames/abstract/space.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -984,7 +984,7 @@ def _place_or_move_agents_to_cells(
984984

985985
if __debug__:
986986
# Check ids presence in model using public API
987-
b_contained = agents.is_in(self.model.sets.ids)
987+
b_contained = agents.is_in(self.model.sets.ids.implode())
988988
if (isinstance(b_contained, Series) and not b_contained.all()) or (
989989
isinstance(b_contained, bool) and not b_contained
990990
):
@@ -1610,7 +1610,7 @@ def remove_agents(
16101610

16111611
if __debug__:
16121612
# Check ids presence in model via public ids
1613-
b_contained = agents.is_in(obj.model.sets.ids)
1613+
b_contained = agents.is_in(obj.model.sets.ids.implode())
16141614
if (isinstance(b_contained, Series) and not b_contained.all()) or (
16151615
isinstance(b_contained, bool) and not b_contained
16161616
):
@@ -1792,7 +1792,7 @@ def _get_df_coords(
17921792
if agents is not None:
17931793
agents = self._get_ids_srs(agents)
17941794
# Check ids presence in model
1795-
b_contained = agents.is_in(self.model.sets.ids)
1795+
b_contained = agents.is_in(self.model.sets.ids.implode())
17961796
if (isinstance(b_contained, Series) and not b_contained.all()) or (
17971797
isinstance(b_contained, bool) and not b_contained
17981798
):
@@ -1872,7 +1872,7 @@ def _place_or_move_agents(
18721872
warn("Some agents are already present in the grid", RuntimeWarning)
18731873

18741874
# Check if agents are present in the model using the public ids
1875-
b_contained = agents.is_in(self.model.sets.ids)
1875+
b_contained = agents.is_in(self.model.sets.ids.implode())
18761876
if (isinstance(b_contained, Series) and not b_contained.all()) or (
18771877
isinstance(b_contained, bool) and not b_contained
18781878
):

mesa_frames/concrete/agentset.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -232,9 +232,11 @@ def contains(
232232
agents: PolarsIdsLike,
233233
) -> bool | pl.Series:
234234
if isinstance(agents, pl.Series):
235-
return agents.is_in(self._df["unique_id"])
235+
return agents.is_in(self._df["unique_id"].implode())
236236
elif isinstance(agents, Collection) and not isinstance(agents, str):
237-
return pl.Series(agents, dtype=pl.UInt64).is_in(self._df["unique_id"])
237+
return pl.Series(agents, dtype=pl.UInt64).is_in(
238+
self._df["unique_id"].implode()
239+
)
238240
else:
239241
return agents in self._df["unique_id"]
240242

@@ -322,7 +324,7 @@ def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Sel
322324
# Normalize to Series of unique_ids
323325
ids = obj._df_index(obj._get_masked_df(agents), "unique_id")
324326
# Validate presence
325-
if not ids.is_in(obj._df["unique_id"]).all():
327+
if not ids.is_in(obj._df["unique_id"].implode()).all():
326328
raise KeyError("Some 'unique_id' of mask are not present in this AgentSet.")
327329
# Remove by ids
328330
return obj._discard(ids)
@@ -396,8 +398,8 @@ def select(
396398
if filter_func:
397399
mask = mask & filter_func(obj)
398400
if n is not None:
399-
mask = (obj._df["unique_id"]).is_in(
400-
obj._df.filter(mask).sample(n)["unique_id"]
401+
mask = obj._df["unique_id"].is_in(
402+
obj._df.filter(mask).sample(n)["unique_id"].implode()
401403
)
402404
if negate:
403405
mask = mask.not_()
@@ -456,7 +458,9 @@ def _concatenate_agentsets(
456458
for obj in iter(agentsets):
457459
# Remove agents that are already in the final DataFrame
458460
final_dfs.append(
459-
obj._df.filter(pl.col("unique_id").is_in(final_indices).not_())
461+
obj._df.filter(
462+
pl.col("unique_id").is_in(final_indices.implode()).not_()
463+
)
460464
)
461465
# Add the indices of the active agents of current AgentSet
462466
final_active_indices.append(obj._df.filter(obj._mask)["unique_id"])
@@ -476,13 +480,13 @@ def _concatenate_agentsets(
476480
final_active_index = pl.concat(
477481
[obj._df.filter(obj._mask)["unique_id"] for obj in agentsets]
478482
)
479-
final_mask = final_df["unique_id"].is_in(final_active_index)
483+
final_mask = final_df["unique_id"].is_in(final_active_index.implode())
480484
self._df = final_df
481485
self._mask = final_mask
482486
# If some ids were removed in the do-method, we need to remove them also from final_df
483487
if not isinstance(original_masked_index, type(None)):
484488
ids_to_remove = original_masked_index.filter(
485-
original_masked_index.is_in(self._df["unique_id"]).not_()
489+
original_masked_index.is_in(self._df["unique_id"].implode()).not_()
486490
)
487491
if not ids_to_remove.is_empty():
488492
self.remove(ids_to_remove, inplace=True)
@@ -499,7 +503,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
499503
and len(mask) == len(self._df)
500504
):
501505
return mask
502-
return self._df["unique_id"].is_in(mask)
506+
return self._df["unique_id"].is_in(mask.implode())
503507

504508
if isinstance(mask, pl.Expr):
505509
return mask
@@ -532,13 +536,13 @@ def _get_masked_df(
532536
):
533537
return self._df.filter(mask)
534538
elif isinstance(mask, pl.DataFrame):
535-
if not mask["unique_id"].is_in(self._df["unique_id"]).all():
539+
if not mask["unique_id"].is_in(self._df["unique_id"].implode()).all():
536540
raise KeyError(
537541
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
538542
)
539543
return mask.select("unique_id").join(self._df, on="unique_id", how="left")
540544
elif isinstance(mask, pl.Series):
541-
if not mask.is_in(self._df["unique_id"]).all():
545+
if not mask.is_in(self._df["unique_id"].implode()).all():
542546
raise KeyError(
543547
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
544548
)
@@ -553,7 +557,7 @@ def _get_masked_df(
553557
mask_series = pl.Series(mask, dtype=pl.UInt64)
554558
else:
555559
mask_series = pl.Series([mask], dtype=pl.UInt64)
556-
if not mask_series.is_in(self._df["unique_id"]).all():
560+
if not mask_series.is_in(self._df["unique_id"].implode()).all():
557561
raise KeyError(
558562
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
559563
)
@@ -585,12 +589,13 @@ def _discard(self, ids: PolarsIdsLike) -> Self:
585589
def _update_mask(
586590
self, original_active_indices: pl.Series, new_indices: pl.Series | None = None
587591
) -> None:
592+
original_active = original_active_indices.implode()
588593
if new_indices is not None:
589-
self._mask = self._df["unique_id"].is_in(
590-
original_active_indices
591-
) | self._df["unique_id"].is_in(new_indices)
594+
self._mask = self._df["unique_id"].is_in(original_active) | self._df[
595+
"unique_id"
596+
].is_in(new_indices.implode())
592597
else:
593-
self._mask = self._df["unique_id"].is_in(original_active_indices)
598+
self._mask = self._df["unique_id"].is_in(original_active)
594599

595600
def __getattr__(self, key: str) -> Any:
596601
if key == "name":

mesa_frames/concrete/agentsetregistry.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -575,7 +575,7 @@ def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame:
575575
[
576576
presence_df,
577577
(
578-
new_ids.is_in(presence_df["unique_id"])
578+
new_ids.is_in(presence_df["unique_id"].implode())
579579
.to_frame("present")
580580
.with_columns(unique_id=new_ids)
581581
.select(["unique_id", "present"])

mesa_frames/concrete/mixin.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ def _df_contains(
206206
column: str,
207207
values: Collection[Any],
208208
) -> pl.Series:
209-
return pl.Series("contains", values).is_in(df[column])
209+
return pl.Series("contains", values).is_in(df[column].implode())
210210

211211
def _df_div(
212212
self,
@@ -290,7 +290,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
290290
):
291291
return mask
292292
assert isinstance(index_cols, str)
293-
return df[index_cols].is_in(mask)
293+
return df[index_cols].is_in(mask.implode())
294294

295295
def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series:
296296
assert index_cols, list[str]
@@ -632,7 +632,7 @@ def _srs_contains(
632632
) -> pl.Series:
633633
if not isinstance(values, Collection):
634634
values = [values]
635-
return pl.Series(values).is_in(pl.Series(srs))
635+
return pl.Series(values).is_in(pl.Series(srs).implode())
636636

637637
def _srs_range(
638638
self,

0 commit comments

Comments
 (0)