1+ from blockference .gridference import *
2+
3+
4+ class GridAgent ():
5+ def __init__ (self , grid_len , num_agents , grid_dim = 2 ) -> None :
6+ self .grid = self .get_grid (grid_len , grid_dim )
7+ self .grid_dim = grid_dim
8+ self .no_actions = 2 * grid_dim + 1
9+ self .agents = self .init_agents (num_agents )
10+
11+ def get_grid (self , grid_len , grid_dim ):
12+ g = list (itertools .product (range (grid_len ), repeat = grid_dim ))
13+ for i , p in enumerate (g ):
14+ g [i ] += (0 ,)
15+ return g
16+
17+ def move_grid (self , agent , chosen_action ):
18+ no_actions = 2 * self .grid_dim
19+ state = list (agent .env_state )
20+ new_state = state .copy ()
21+
22+ # here
23+
24+ if chosen_action == 0 : # STAY
25+ new_state = state
26+ else :
27+ if chosen_action % 2 == 1 :
28+ index = (chosen_action + 1 ) / 2
29+ new_state [index ] = state [index ] - 1 if state [index ] > 0 else state [index ]
30+ elif chosen_action % 2 == 0 :
31+ index = chosen_action / 2
32+ new_state [index ] = state [index ] + 1 if state [index ] < agent .border else state [index ]
33+ return new_state
34+
35+ def init_agents (self , no_agents ):
36+ # create a dict of agents
37+ agents = {}
38+
39+ for a in range (no_agents ):
40+ # create new agent
41+ agent = ActiveGridference (self .grid )
42+ # generate target state
43+ target = (rand .randint (0 , 9 ), rand .randint (0 , 9 ))
44+ # add target state
45+ agent .get_C (target + (0 ,))
46+ # all agents start in the same position
47+ start = (rand .randint (0 , 9 ), rand .randint (0 , 9 ))
48+ agent .get_D (start + (1 ,))
49+
50+ agents [a ] = agent
51+
52+ return agents
53+
54+ def actinf_dict (self , agents_dict , g_agent ):
55+ # list of all updates to the agents in the network
56+ agent_updates = []
57+
58+ for source , agent in agents_dict .items ():
59+
60+ policies = construct_policies ([agent .n_states ], [len (agent .E )], policy_len = agent .policy_len )
61+ # get obs_idx
62+ obs_idx = g_agent .grid .index (agent .env_state )
63+
64+ # infer_states
65+ qs_current = u .infer_states (obs_idx , agent .A , agent .prior )
66+
67+ # calc efe
68+ _G = u .calculate_G_policies (agent .A , agent .B , agent .C , qs_current , policies = policies )
69+
70+ # calc action posterior
71+ Q_pi = u .softmax (- _G )
72+ # compute the probability of each action
73+ P_u = u .compute_prob_actions (agent .E , policies , Q_pi )
74+
75+ # sample action
76+ chosen_action = u .sample (P_u )
77+
78+ # calc next prior
79+ prior = agent .B [:, :, chosen_action ].dot (qs_current )
80+
81+ # update env state
82+ # action_label = params['actions'][chosen_action]
83+
84+ current_state = self .move_2d (agent , chosen_action ) # store the new grid location
85+ agent_update = {'source' : source ,
86+ 'update_prior' : prior ,
87+ 'update_env' : current_state ,
88+ 'update_action' : chosen_action ,
89+ 'update_inference' : qs_current }
90+ agent_updates .append (agent_update )
91+
92+ return {'agent_updates' : agent_updates }
93+
94+ def move_2d (self , agent , chosen_action ):
95+ (Y , X ) = agent .env_state
96+ Y_new = Y
97+ X_new = X
98+ # here
99+
100+ if chosen_action == 0 : # UP
101+
102+ Y_new = Y - 1 if Y > 0 else Y
103+ X_new = X
104+
105+ elif chosen_action == 1 : # DOWN
106+
107+ Y_new = Y + 1 if Y < agent .border else Y
108+ X_new = X
109+
110+ elif chosen_action == 2 : # LEFT
111+ Y_new = Y
112+ X_new = X - 1 if X > 0 else X
113+
114+ elif chosen_action == 3 : # RIGHT
115+ Y_new = Y
116+ X_new = X + 1 if X < agent .border else X
117+
118+ elif chosen_action == 4 : # STAY
119+ Y_new , X_new = Y , X
120+
121+ return (X_new , Y_new )
122+
123+ def move_3d (self , agent , chosen_action ):
124+ (Y , X , Z ) = agent .env_state
125+ Y_new = Y
126+ X_new = X
127+ Z_new = Z
128+ # here
129+
130+ if chosen_action == 0 : # UP
131+
132+ Y_new = Y - 1 if Y > 0 else Y
133+ X_new = X
134+ Z_new = Z
135+
136+ elif chosen_action == 1 : # DOWN
137+
138+ Y_new = Y + 1 if Y < agent .border else Y
139+ X_new = X
140+ Z_new = Z
141+
142+ elif chosen_action == 2 : # LEFT
143+ Y_new = Y
144+ X_new = X - 1 if X > 0 else X
145+ Z_new = Z
146+
147+ elif chosen_action == 3 : # RIGHT
148+ Y_new = Y
149+ X_new = X + 1 if X < agent .border else X
150+ Z_new = Z
151+
152+ elif chosen_action == 4 : # IN
153+ X_new = X
154+ Y_new = Y
155+ Z_new = Z + 1 if Z < agent .border else Z
156+
157+ elif chosen_action == 5 : # OUT
158+ X_new = X
159+ Y_new = Y
160+ Z_new = Z - 1 if Z > agent .border else Z
161+
162+ elif chosen_action == 6 : # STAY
163+ Y_new , X_new , Z_new = Y , X , Z
164+
165+ return (X_new , Y_new , Z_new )
0 commit comments