Skip to content

Commit 7afe600

Browse files
Added move_to function in DiscreteSpaceDF class
Added the move_to function to allow agent movement based on specified attributes and ranking orders. The function considers neighborhood radius, sorting preferences, and optional shuffling for random movement. It ensures conflict resolution using priority-based selection, optimizing movement allocation. This enhances the agent-based model's flexibility and realism.
1 parent df26a23 commit 7afe600

File tree

1 file changed

+108
-2
lines changed

1 file changed

+108
-2
lines changed

mesa_frames/abstract/space.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,17 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
5050
from abc import abstractmethod
5151
from collections.abc import Callable, Collection, Sequence, Sized
5252
from itertools import product
53-
from typing import TYPE_CHECKING, Literal
53+
from typing import TYPE_CHECKING, Literal, TypeVar, Union
5454
from warnings import warn
5555

5656
import numpy as np
5757
import polars as pl
5858
from numpy.random import Generator
5959
from typing_extensions import Any, Self
6060

61-
from mesa_frames import AgentsDF
61+
62+
from mesa_frames.concrete.polars.agentset import AgentSetPolars
63+
from mesa_frames.concrete.agents import AgentsDF
6264
from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
6365
from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
6466
from mesa_frames.types_ import (
@@ -79,6 +81,9 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
7981

8082
ESPG = int
8183

84+
85+
AgentLike = Union[AgentSetPolars, pl.DataFrame]
86+
8287
if TYPE_CHECKING:
8388
from mesa_frames.concrete.model import ModelDF
8489

@@ -1036,6 +1041,107 @@ def __repr__(self) -> str:
10361041
def __str__(self) -> str:
10371042
return f"{self.__class__.__name__}\n{str(self.cells)}"
10381043

1044+
def move_to(
1045+
self,
1046+
agents: AgentLike,
1047+
attr_names: str | list[str],
1048+
rank_order: str | list[str] = "max",
1049+
radius: int | pl.Series = None,
1050+
include_center: bool = True,
1051+
shuffle: bool = True
1052+
) -> None:
1053+
if isinstance(attr_names, str):
1054+
attr_names = [attr_names]
1055+
if isinstance(rank_order, str):
1056+
rank_order = [rank_order] * len(attr_names)
1057+
if len(attr_names) != len(rank_order):
1058+
raise ValueError("attr_names and rank_order must have the same length")
1059+
if radius is None:
1060+
if "vision" in agents.columns:
1061+
radius = agents["vision"]
1062+
else:
1063+
raise ValueError("radius must be specified if agents do not have a 'vision' attribute")
1064+
neighborhood = self.get_neighborhood(
1065+
radius=radius,
1066+
agents=agents,
1067+
include_center=include_center
1068+
)
1069+
neighborhood = neighborhood.join(self.cells, on=["dim_0", "dim_1"])
1070+
neighborhood = neighborhood.with_columns(
1071+
agent_id_center=neighborhood.join(
1072+
agents.pos,
1073+
left_on=["dim_0_center", "dim_1_center"],
1074+
right_on=["dim_0", "dim_1"],
1075+
)["unique_id"]
1076+
)
1077+
if shuffle:
1078+
agent_order = (
1079+
neighborhood
1080+
.unique(subset=["agent_id_center"], keep="first")
1081+
.select("agent_id_center")
1082+
.sample(fraction=1.0, seed=self.model.random.integers(0, 2**31-1))
1083+
.with_row_index("agent_order")
1084+
)
1085+
else:
1086+
agent_order = (
1087+
neighborhood
1088+
.unique(subset=["agent_id_center"], keep="first", maintain_order=True)
1089+
.with_row_index("agent_order")
1090+
.select(["agent_id_center", "agent_order"])
1091+
)
1092+
neighborhood = neighborhood.join(agent_order, on="agent_id_center")
1093+
sort_cols = []
1094+
sort_desc = []
1095+
for attr, order in zip(attr_names, rank_order):
1096+
sort_cols.append(attr)
1097+
sort_desc.append(order.lower() == "max")
1098+
neighborhood = neighborhood.sort(
1099+
sort_cols + ["radius", "dim_0", "dim_1"],
1100+
descending=sort_desc + [False, False, False]
1101+
)
1102+
neighborhood = neighborhood.join(
1103+
agent_order.select(
1104+
pl.col("agent_id_center").alias("agent_id"),
1105+
pl.col("agent_order").alias("blocking_agent_order"),
1106+
),
1107+
on="agent_id",
1108+
how="left",
1109+
).rename({"agent_id": "blocking_agent_id"})
1110+
best_moves = pl.DataFrame()
1111+
while len(best_moves) < len(agents):
1112+
neighborhood = neighborhood.with_columns(
1113+
priority=pl.col("agent_order").cum_count().over(["dim_0", "dim_1"])
1114+
)
1115+
new_best_moves = (
1116+
neighborhood.group_by("agent_id_center", maintain_order=True)
1117+
.first()
1118+
.unique(subset=["dim_0", "dim_1"], keep="first", maintain_order=True)
1119+
)
1120+
condition = pl.col("blocking_agent_id").is_null() | (
1121+
pl.col("blocking_agent_id") == pl.col("agent_id_center")
1122+
)
1123+
if len(best_moves) > 0:
1124+
condition = condition | pl.col("blocking_agent_id").is_in(
1125+
best_moves["agent_id_center"]
1126+
)
1127+
condition = condition & (pl.col("priority") == 1)
1128+
new_best_moves = new_best_moves.filter(condition)
1129+
if len(new_best_moves) == 0:
1130+
break
1131+
best_moves = pl.concat([best_moves, new_best_moves])
1132+
neighborhood = neighborhood.filter(
1133+
~pl.col("agent_id_center").is_in(best_moves["agent_id_center"])
1134+
)
1135+
neighborhood = neighborhood.join(
1136+
best_moves.select(["dim_0", "dim_1"]), on=["dim_0", "dim_1"], how="anti"
1137+
)
1138+
if len(best_moves) > 0:
1139+
self.move_agents(
1140+
best_moves.sort("agent_order")["agent_id_center"],
1141+
best_moves.sort("agent_order").select(["dim_0", "dim_1"])
1142+
)
1143+
1144+
10391145
@property
10401146
def cells(self) -> DataFrame:
10411147
"""

0 commit comments

Comments
 (0)