1- import enum
1+ import enum # TODO: enum appears unused in this file; check whether another script relies on this import, otherwise remove it.
22import sys
3-
4- from matplotlib .pyplot import grid
5-
6- # adding tools to the system path
7- sys .path .insert (0 , '../tools/' )
8-
93from tools .model import ActiveGridference
104from tools .control import construct_policies
115import tools .utils as u
126import random as rand
137import itertools
148
9+ from matplotlib .pyplot import grid
10+
11+ # adding tools to the system path
12+ sys .path .insert (0 , '../tools/' )
13+
1514
1615def actinf_planning_single (agent , env_state , A , B , C , prior ):
17- policies = construct_policies ([agent .n_states ], [len (agent .E )], policy_len = agent .policy_len )
16+ policies = construct_policies ([agent .n_states ],
17+ [len (agent .E )],
18+ policy_len = agent .policy_len )
1819 # get obs_idx
1920 obs_idx = grid .index (env_state )
2021
@@ -34,7 +35,7 @@ def actinf_planning_single(agent, env_state, A, B, C, prior):
3435 chosen_action = u .sample (P_u )
3536
3637 # calc next prior
37- prior = B [:,:, chosen_action ].dot (qs_current )
38+ prior = B [:, :, chosen_action ].dot (qs_current )
3839
3940 # update env state
4041 # action_label = params['actions'][chosen_action]
@@ -43,42 +44,43 @@ def actinf_planning_single(agent, env_state, A, B, C, prior):
4344 Y_new = Y
4445 X_new = X
4546
46- if chosen_action == 0 : # UP
47-
47+ if chosen_action == 0 : # UP
48+
4849 Y_new = Y - 1 if Y > 0 else Y
4950 X_new = X
5051
51- elif chosen_action == 1 : # DOWN
52+ elif chosen_action == 1 : # DOWN
5253
5354 Y_new = Y + 1 if Y < agent .border else Y
5455 X_new = X
5556
56- elif chosen_action == 2 : # LEFT
57+ elif chosen_action == 2 : # LEFT
5758 Y_new = Y
5859 X_new = X - 1 if X > 0 else X
5960
60- elif chosen_action == 3 : # RIGHT
61+ elif chosen_action == 3 : # RIGHT
6162 Y_new = Y
62- X_new = X + 1 if X < agent .border else X
63+ X_new = X + 1 if X < agent .border else X
6364
64- elif chosen_action == 4 : # STAY
65- Y_new , X_new = Y , X
66-
67- current_state = (Y_new , X_new ) # store the new grid location
65+ elif chosen_action == 4 : # STAY
66+ Y_new , X_new = Y , X
67+
68+ current_state = (Y_new , X_new ) # store the new grid location
6869
6970 return {'update_prior' : prior ,
7071 'update_env' : current_state ,
7172 'update_action' : chosen_action ,
7273 'update_inference' : qs_current }
7374
75+
7476def actinf_graph (agent_network ):
7577
7678 # list of all updates to the agents in the network
7779 agent_updates = []
7880
7981 for agent in agent_network .nodes :
8082
81- policies = construct_policies ([agent_network .nodes [agent ]['agent' ].n_states ], [len (agent_network .nodes [agent ]['agent' ].E )], policy_len = agent_network .nodes [agent ]['agent' ].policy_len )
83+ policies = construct_policies ([agent_network .nodes [agent ]['agent' ].n_states ], [len (agent_network .nodes [agent ]['agent' ].E )], policy_len = agent_network .nodes [agent ]['agent' ].policy_len )
8284 # get obs_idx
8385 obs_idx = grid .index (agent_network .nodes [agent ]['env_state' ])
8486
@@ -92,12 +94,12 @@ def actinf_graph(agent_network):
9294 Q_pi = u .softmax (- _G )
9395 # compute the probability of each action
9496 P_u = u .compute_prob_actions (agent_network .nodes [agent ]['agent' ].E , policies , Q_pi )
95-
97+
9698 # sample action
9799 chosen_action = u .sample (P_u )
98100
99101 # calc next prior
100- prior = agent_network .nodes [agent ]['prior_B' ][:,:, chosen_action ].dot (qs_current )
102+ prior = agent_network .nodes [agent ]['prior_B' ][:, :, chosen_action ].dot (qs_current )
101103
102104 # update env state
103105 # action_label = params['actions'][chosen_action]
@@ -107,28 +109,28 @@ def actinf_graph(agent_network):
107109 X_new = X
108110 # here
109111
110- if chosen_action == 0 : # UP
111-
112+ if chosen_action == 0 : # UP
113+
112114 Y_new = Y - 1 if Y > 0 else Y
113115 X_new = X
114116
115- elif chosen_action == 1 : # DOWN
117+ elif chosen_action == 1 : # DOWN
116118
117119 Y_new = Y + 1 if Y < agent_network .nodes [agent ]['agent' ].border else Y
118120 X_new = X
119121
120- elif chosen_action == 2 : # LEFT
122+ elif chosen_action == 2 : # LEFT
121123 Y_new = Y
122124 X_new = X - 1 if X > 0 else X
123125
124- elif chosen_action == 3 : # RIGHT
126+ elif chosen_action == 3 : # RIGHT
125127 Y_new = Y
126- X_new = X + 1 if X < agent_network .nodes [agent ]['agent' ].border else X
128+ X_new = X + 1 if X < agent_network .nodes [agent ]['agent' ].border else X
129+
130+ elif chosen_action == 4 : # STAY
131+ Y_new , X_new = Y , X
127132
128- elif chosen_action == 4 : # STAY
129- Y_new , X_new = Y , X
130-
131- current_state = (Y_new , X_new ) # store the new grid location
133+ current_state = (Y_new , X_new ) # store the new grid location
132134 agent_update = {'source' : agent ,
133135 'update_prior' : prior ,
134136 'update_env' : current_state ,
@@ -138,6 +140,7 @@ def actinf_graph(agent_network):
138140
139141 return {'agent_updates' : agent_updates }
140142
143+
141144class GridAgent ():
142145 def __init__ (self , grid_len , num_agents , grid_dim = 2 ) -> None :
143146 self .grid = self .get_grid (grid_len , grid_dim )
@@ -157,24 +160,24 @@ def init_agents(self, no_agents):
157160 # create new agent
158161 agent = ActiveGridference (self .grid )
159162 # generate target state
160- target = (rand .randint (0 ,9 ), rand .randint (0 ,9 ))
163+ target = (rand .randint (0 , 9 ), rand .randint (0 , 9 ))
161164 # add target state
162165 agent .get_C (target + (0 ,))
163166 # all agents start in the same position
164- start = (rand .randint (0 ,9 ), rand .randint (0 ,9 ))
167+ start = (rand .randint (0 , 9 ), rand .randint (0 , 9 ))
165168 agent .get_D (start + (1 ,))
166169
167170 agents [a ] = agent
168171
169172 return agents
170-
173+
171174 def actinf_dict (self , agents_dict , g_agent ):
172175 # list of all updates to the agents in the network
173176 agent_updates = []
174177
175178 for source , agent in agents_dict .items ():
176179
177- policies = construct_policies ([agent .n_states ], [len (agent .E )], policy_len = agent .policy_len )
180+ policies = construct_policies ([agent .n_states ], [len (agent .E )], policy_len = agent .policy_len )
178181 # get obs_idx
179182 obs_idx = g_agent .grid .index (agent .env_state )
180183
@@ -193,12 +196,12 @@ def actinf_dict(self, agents_dict, g_agent):
193196 chosen_action = u .sample (P_u )
194197
195198 # calc next prior
196- prior = agent .B [:,:, chosen_action ].dot (qs_current )
199+ prior = agent .B [:, :, chosen_action ].dot (qs_current )
197200
198201 # update env state
199202 # action_label = params['actions'][chosen_action]
200-
201- current_state = self .move_2d (agent , chosen_action ) # store the new grid location
203+
204+ current_state = self .move_2d (agent , chosen_action ) # store the new grid location
202205 agent_update = {'source' : source ,
203206 'update_prior' : prior ,
204207 'update_env' : current_state ,
@@ -214,25 +217,25 @@ def move_2d(self, agent, chosen_action):
214217 X_new = X
215218 # here
216219
217- if chosen_action == 0 : # UP
218-
220+ if chosen_action == 0 : # UP
221+
219222 Y_new = Y - 1 if Y > 0 else Y
220223 X_new = X
221224
222- elif chosen_action == 1 : # DOWN
225+ elif chosen_action == 1 : # DOWN
223226
224227 Y_new = Y + 1 if Y < agent .border else Y
225228 X_new = X
226229
227- elif chosen_action == 2 : # LEFT
230+ elif chosen_action == 2 : # LEFT
228231 Y_new = Y
229232 X_new = X - 1 if X > 0 else X
230233
231- elif chosen_action == 3 : # RIGHT
234+ elif chosen_action == 3 : # RIGHT
232235 Y_new = Y
233- X_new = X + 1 if X < agent .border else X
236+ X_new = X + 1 if X < agent .border else X
234237
235- elif chosen_action == 4 : # STAY
238+ elif chosen_action == 4 : # STAY
236239 Y_new , X_new = Y , X
237240
238241 return (X_new , Y_new )
@@ -244,39 +247,39 @@ def move_3d(self, agent, chosen_action):
244247 Z_new = Z
245248 # here
246249
247- if chosen_action == 0 : # UP
248-
250+ if chosen_action == 0 : # UP
251+
249252 Y_new = Y - 1 if Y > 0 else Y
250253 X_new = X
251254 Z_new = Z
252255
253- elif chosen_action == 1 : # DOWN
256+ elif chosen_action == 1 : # DOWN
254257
255258 Y_new = Y + 1 if Y < agent .border else Y
256259 X_new = X
257260 Z_new = Z
258261
259- elif chosen_action == 2 : # LEFT
262+ elif chosen_action == 2 : # LEFT
260263 Y_new = Y
261264 X_new = X - 1 if X > 0 else X
262265 Z_new = Z
263266
264- elif chosen_action == 3 : # RIGHT
267+ elif chosen_action == 3 : # RIGHT
265268 Y_new = Y
266- X_new = X + 1 if X < agent .border else X
269+ X_new = X + 1 if X < agent .border else X
267270 Z_new = Z
268271
269- elif chosen_action == 4 : # IN
272+ elif chosen_action == 4 : # IN
270273 X_new = X
271274 Y_new = Y
272275 Z_new = Z + 1 if Z < agent .border else Z
273276
274- elif chosen_action == 5 : # OUT
277+ elif chosen_action == 5 : # OUT
275278 X_new = X
276279 Y_new = Y
277- Z_new = Z - 1 if Z > agent .border else Z
280+ Z_new = Z - 1 if Z > agent .border else Z
278281
279- elif chosen_action == 6 : # STAY
282+ elif chosen_action == 6 : # STAY
280283 Y_new , X_new , Z_new = Y , X , Z
281284
282285 return (X_new , Y_new , Z_new )
0 commit comments