@@ -144,13 +144,33 @@ def actinf_graph(agent_network):
144144class GridAgent ():
145145 def __init__ (self , grid_len , num_agents , grid_dim = 2 ) -> None :
146146 self .grid = self .get_grid (grid_len , grid_dim )
147+ self .grid_dim = grid_dim
148+ self .no_actions = 2 * grid_dim + 1
147149 self .agents = self .init_agents (num_agents )
148150
149151 def get_grid (self , grid_len , grid_dim ):
150152 g = list (itertools .product (range (grid_len ), repeat = grid_dim ))
151153 for i , p in enumerate (g ):
152154 g [i ] += (0 ,)
153155 return g
156+
157+ def move_grid (self , agent , chosen_action ):
158+ no_actions = 2 * self .grid_dim
159+ state = list (agent .env_state )
160+ new_state = state .copy ()
161+
162+ # here
163+
164+ if chosen_action == 0 : # STAY
165+ new_state = state
166+ else :
167+ if chosen_action % 2 == 1 :
168+ index = (chosen_action + 1 ) / 2
169+ new_state [index ] = state [index ] - 1 if state [index ] > 0 else state [index ]
170+ elif chosen_action % 2 == 0 :
171+ index = chosen_action / 2
172+ new_state [index ] = state [index ] + 1 if state [index ] < agent .border else state [index ]
173+ return new_state
154174
155175 def init_agents (self , no_agents ):
156176 # create a dict of agents
@@ -282,4 +302,4 @@ def move_3d(self, agent, chosen_action):
282302 elif chosen_action == 6 : # STAY
283303 Y_new , X_new , Z_new = Y , X , Z
284304
285- return (X_new , Y_new , Z_new )
305+ return (X_new , Y_new , Z_new )
0 commit comments