@@ -50,15 +50,17 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
5050from abc import abstractmethod
5151from collections .abc import Callable , Collection , Sequence , Sized
5252from itertools import product
53- from typing import TYPE_CHECKING , Literal
53+ from typing import TYPE_CHECKING , Literal , TypeVar , Union
5454from warnings import warn
5555
5656import numpy as np
5757import polars as pl
5858from numpy .random import Generator
5959from typing_extensions import Any , Self
6060
61- from mesa_frames import AgentsDF
61+
62+ from mesa_frames .concrete .polars .agentset import AgentSetPolars
63+ from mesa_frames .concrete .agents import AgentsDF
6264from mesa_frames .abstract .agents import AgentContainer , AgentSetDF
6365from mesa_frames .abstract .mixin import CopyMixin , DataFrameMixin
6466from mesa_frames .types_ import (
@@ -79,6 +81,9 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
7981
8082ESPG = int
8183
84+
85+ AgentLike = Union [AgentSetPolars , pl .DataFrame ]
86+
8287if TYPE_CHECKING :
8388 from mesa_frames .concrete .model import ModelDF
8489
@@ -1036,6 +1041,107 @@ def __repr__(self) -> str:
10361041 def __str__ (self ) -> str :
10371042 return f"{ self .__class__ .__name__ } \n { str (self .cells )} "
10381043
1044+ def move_to (
1045+ self ,
1046+ agents : AgentLike ,
1047+ attr_names : str | list [str ],
1048+ rank_order : str | list [str ] = "max" ,
1049+ radius : int | pl .Series = None ,
1050+ include_center : bool = True ,
1051+ shuffle : bool = True
1052+ ) -> None :
1053+ if isinstance (attr_names , str ):
1054+ attr_names = [attr_names ]
1055+ if isinstance (rank_order , str ):
1056+ rank_order = [rank_order ] * len (attr_names )
1057+ if len (attr_names ) != len (rank_order ):
1058+ raise ValueError ("attr_names and rank_order must have the same length" )
1059+ if radius is None :
1060+ if "vision" in agents .columns :
1061+ radius = agents ["vision" ]
1062+ else :
1063+ raise ValueError ("radius must be specified if agents do not have a 'vision' attribute" )
1064+ neighborhood = self .get_neighborhood (
1065+ radius = radius ,
1066+ agents = agents ,
1067+ include_center = include_center
1068+ )
1069+ neighborhood = neighborhood .join (self .cells , on = ["dim_0" , "dim_1" ])
1070+ neighborhood = neighborhood .with_columns (
1071+ agent_id_center = neighborhood .join (
1072+ agents .pos ,
1073+ left_on = ["dim_0_center" , "dim_1_center" ],
1074+ right_on = ["dim_0" , "dim_1" ],
1075+ )["unique_id" ]
1076+ )
1077+ if shuffle :
1078+ agent_order = (
1079+ neighborhood
1080+ .unique (subset = ["agent_id_center" ], keep = "first" )
1081+ .select ("agent_id_center" )
1082+ .sample (fraction = 1.0 , seed = self .model .random .integers (0 , 2 ** 31 - 1 ))
1083+ .with_row_index ("agent_order" )
1084+ )
1085+ else :
1086+ agent_order = (
1087+ neighborhood
1088+ .unique (subset = ["agent_id_center" ], keep = "first" , maintain_order = True )
1089+ .with_row_index ("agent_order" )
1090+ .select (["agent_id_center" , "agent_order" ])
1091+ )
1092+ neighborhood = neighborhood .join (agent_order , on = "agent_id_center" )
1093+ sort_cols = []
1094+ sort_desc = []
1095+ for attr , order in zip (attr_names , rank_order ):
1096+ sort_cols .append (attr )
1097+ sort_desc .append (order .lower () == "max" )
1098+ neighborhood = neighborhood .sort (
1099+ sort_cols + ["radius" , "dim_0" , "dim_1" ],
1100+ descending = sort_desc + [False , False , False ]
1101+ )
1102+ neighborhood = neighborhood .join (
1103+ agent_order .select (
1104+ pl .col ("agent_id_center" ).alias ("agent_id" ),
1105+ pl .col ("agent_order" ).alias ("blocking_agent_order" ),
1106+ ),
1107+ on = "agent_id" ,
1108+ how = "left" ,
1109+ ).rename ({"agent_id" : "blocking_agent_id" })
1110+ best_moves = pl .DataFrame ()
1111+ while len (best_moves ) < len (agents ):
1112+ neighborhood = neighborhood .with_columns (
1113+ priority = pl .col ("agent_order" ).cum_count ().over (["dim_0" , "dim_1" ])
1114+ )
1115+ new_best_moves = (
1116+ neighborhood .group_by ("agent_id_center" , maintain_order = True )
1117+ .first ()
1118+ .unique (subset = ["dim_0" , "dim_1" ], keep = "first" , maintain_order = True )
1119+ )
1120+ condition = pl .col ("blocking_agent_id" ).is_null () | (
1121+ pl .col ("blocking_agent_id" ) == pl .col ("agent_id_center" )
1122+ )
1123+ if len (best_moves ) > 0 :
1124+ condition = condition | pl .col ("blocking_agent_id" ).is_in (
1125+ best_moves ["agent_id_center" ]
1126+ )
1127+ condition = condition & (pl .col ("priority" ) == 1 )
1128+ new_best_moves = new_best_moves .filter (condition )
1129+ if len (new_best_moves ) == 0 :
1130+ break
1131+ best_moves = pl .concat ([best_moves , new_best_moves ])
1132+ neighborhood = neighborhood .filter (
1133+ ~ pl .col ("agent_id_center" ).is_in (best_moves ["agent_id_center" ])
1134+ )
1135+ neighborhood = neighborhood .join (
1136+ best_moves .select (["dim_0" , "dim_1" ]), on = ["dim_0" , "dim_1" ], how = "anti"
1137+ )
1138+ if len (best_moves ) > 0 :
1139+ self .move_agents (
1140+ best_moves .sort ("agent_order" )["agent_id_center" ],
1141+ best_moves .sort ("agent_order" ).select (["dim_0" , "dim_1" ])
1142+ )
1143+
1144+
10391145 @property
10401146 def cells (self ) -> DataFrame :
10411147 """
0 commit comments