Skip to content

Commit ee83559

Browse files
committed
fix merge
2 parents a48577c + b1952ca commit ee83559

File tree

9 files changed

+451
-369
lines changed

9 files changed

+451
-369
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
__pycache__*
2+
*.ini
3+
*.xml
4+
*.iml

blockference/gridference.py

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
import enum
1+
import enum  # TODO: appears unused here; confirm no other script relies on it, then remove.
22
import sys
3-
4-
from matplotlib.pyplot import grid
5-
6-
# adding tools to the system path
7-
sys.path.insert(0, '../tools/')
8-
93
from tools.model import ActiveGridference
104
from tools.control import construct_policies
115
import tools.utils as u
126
import random as rand
137
import itertools
148

9+
from matplotlib.pyplot import grid
10+
11+
# adding tools to the system path
12+
sys.path.insert(0, '../tools/')
13+
1514

1615
def actinf_planning_single(agent, env_state, A, B, C, prior):
17-
policies = construct_policies([agent.n_states], [len(agent.E)], policy_len = agent.policy_len)
16+
policies = construct_policies([agent.n_states],
17+
[len(agent.E)],
18+
policy_len=agent.policy_len)
1819
# get obs_idx
1920
obs_idx = grid.index(env_state)
2021

@@ -34,7 +35,7 @@ def actinf_planning_single(agent, env_state, A, B, C, prior):
3435
chosen_action = u.sample(P_u)
3536

3637
# calc next prior
37-
prior = B[:,:,chosen_action].dot(qs_current)
38+
prior = B[:, :, chosen_action].dot(qs_current)
3839

3940
# update env state
4041
# action_label = params['actions'][chosen_action]
@@ -43,42 +44,43 @@ def actinf_planning_single(agent, env_state, A, B, C, prior):
4344
Y_new = Y
4445
X_new = X
4546

46-
if chosen_action == 0: # UP
47-
47+
if chosen_action == 0: # UP
48+
4849
Y_new = Y - 1 if Y > 0 else Y
4950
X_new = X
5051

51-
elif chosen_action == 1: # DOWN
52+
elif chosen_action == 1: # DOWN
5253

5354
Y_new = Y + 1 if Y < agent.border else Y
5455
X_new = X
5556

56-
elif chosen_action == 2: # LEFT
57+
elif chosen_action == 2: # LEFT
5758
Y_new = Y
5859
X_new = X - 1 if X > 0 else X
5960

60-
elif chosen_action == 3: # RIGHT
61+
elif chosen_action == 3: # RIGHT
6162
Y_new = Y
62-
X_new = X +1 if X < agent.border else X
63+
X_new = X + 1 if X < agent.border else X
6364

64-
elif chosen_action == 4: # STAY
65-
Y_new, X_new = Y, X
66-
67-
current_state = (Y_new, X_new) # store the new grid location
65+
elif chosen_action == 4: # STAY
66+
Y_new, X_new = Y, X
67+
68+
current_state = (Y_new, X_new) # store the new grid location
6869

6970
return {'update_prior': prior,
7071
'update_env': current_state,
7172
'update_action': chosen_action,
7273
'update_inference': qs_current}
7374

75+
7476
def actinf_graph(agent_network):
7577

7678
# list of all updates to the agents in the network
7779
agent_updates = []
7880

7981
for agent in agent_network.nodes:
8082

81-
policies = construct_policies([agent_network.nodes[agent]['agent'].n_states], [len(agent_network.nodes[agent]['agent'].E)], policy_len = agent_network.nodes[agent]['agent'].policy_len)
83+
policies = construct_policies([agent_network.nodes[agent]['agent'].n_states], [len(agent_network.nodes[agent]['agent'].E)], policy_len=agent_network.nodes[agent]['agent'].policy_len)
8284
# get obs_idx
8385
obs_idx = grid.index(agent_network.nodes[agent]['env_state'])
8486

@@ -92,12 +94,12 @@ def actinf_graph(agent_network):
9294
Q_pi = u.softmax(-_G)
9395
# compute the probability of each action
9496
P_u = u.compute_prob_actions(agent_network.nodes[agent]['agent'].E, policies, Q_pi)
95-
97+
9698
# sample action
9799
chosen_action = u.sample(P_u)
98100

99101
# calc next prior
100-
prior = agent_network.nodes[agent]['prior_B'][:,:,chosen_action].dot(qs_current)
102+
prior = agent_network.nodes[agent]['prior_B'][:, :, chosen_action].dot(qs_current)
101103

102104
# update env state
103105
# action_label = params['actions'][chosen_action]
@@ -107,28 +109,28 @@ def actinf_graph(agent_network):
107109
X_new = X
108110
# here
109111

110-
if chosen_action == 0: # UP
111-
112+
if chosen_action == 0: # UP
113+
112114
Y_new = Y - 1 if Y > 0 else Y
113115
X_new = X
114116

115-
elif chosen_action == 1: # DOWN
117+
elif chosen_action == 1: # DOWN
116118

117119
Y_new = Y + 1 if Y < agent_network.nodes[agent]['agent'].border else Y
118120
X_new = X
119121

120-
elif chosen_action == 2: # LEFT
122+
elif chosen_action == 2: # LEFT
121123
Y_new = Y
122124
X_new = X - 1 if X > 0 else X
123125

124-
elif chosen_action == 3: # RIGHT
126+
elif chosen_action == 3: # RIGHT
125127
Y_new = Y
126-
X_new = X +1 if X < agent_network.nodes[agent]['agent'].border else X
128+
X_new = X + 1 if X < agent_network.nodes[agent]['agent'].border else X
129+
130+
elif chosen_action == 4: # STAY
131+
Y_new, X_new = Y, X
127132

128-
elif chosen_action == 4: # STAY
129-
Y_new, X_new = Y, X
130-
131-
current_state = (Y_new, X_new) # store the new grid location
133+
current_state = (Y_new, X_new) # store the new grid location
132134
agent_update = {'source': agent,
133135
'update_prior': prior,
134136
'update_env': current_state,
@@ -138,6 +140,7 @@ def actinf_graph(agent_network):
138140

139141
return {'agent_updates': agent_updates}
140142

143+
141144
class GridAgent():
142145
def __init__(self, grid_len, num_agents, grid_dim=2) -> None:
143146
self.grid = self.get_grid(grid_len, grid_dim)
@@ -157,24 +160,24 @@ def init_agents(self, no_agents):
157160
# create new agent
158161
agent = ActiveGridference(self.grid)
159162
# generate target state
160-
target = (rand.randint(0,9), rand.randint(0,9))
163+
target = (rand.randint(0, 9), rand.randint(0, 9))
161164
# add target state
162165
agent.get_C(target + (0,))
163166
# each agent starts at its own randomly drawn position
164-
start = (rand.randint(0,9), rand.randint(0,9))
167+
start = (rand.randint(0, 9), rand.randint(0, 9))
165168
agent.get_D(start + (1,))
166169

167170
agents[a] = agent
168171

169172
return agents
170-
173+
171174
def actinf_dict(self, agents_dict, g_agent):
172175
# list of all updates to the agents in the network
173176
agent_updates = []
174177

175178
for source, agent in agents_dict.items():
176179

177-
policies = construct_policies([agent.n_states], [len(agent.E)], policy_len = agent.policy_len)
180+
policies = construct_policies([agent.n_states], [len(agent.E)], policy_len=agent.policy_len)
178181
# get obs_idx
179182
obs_idx = g_agent.grid.index(agent.env_state)
180183

@@ -193,12 +196,12 @@ def actinf_dict(self, agents_dict, g_agent):
193196
chosen_action = u.sample(P_u)
194197

195198
# calc next prior
196-
prior = agent.B[:,:,chosen_action].dot(qs_current)
199+
prior = agent.B[:, :, chosen_action].dot(qs_current)
197200

198201
# update env state
199202
# action_label = params['actions'][chosen_action]
200-
201-
current_state = self.move_2d(agent, chosen_action) # store the new grid location
203+
204+
current_state = self.move_2d(agent, chosen_action) # store the new grid location
202205
agent_update = {'source': source,
203206
'update_prior': prior,
204207
'update_env': current_state,
@@ -214,25 +217,25 @@ def move_2d(self, agent, chosen_action):
214217
X_new = X
215218
# here
216219

217-
if chosen_action == 0: # UP
218-
220+
if chosen_action == 0: # UP
221+
219222
Y_new = Y - 1 if Y > 0 else Y
220223
X_new = X
221224

222-
elif chosen_action == 1: # DOWN
225+
elif chosen_action == 1: # DOWN
223226

224227
Y_new = Y + 1 if Y < agent.border else Y
225228
X_new = X
226229

227-
elif chosen_action == 2: # LEFT
230+
elif chosen_action == 2: # LEFT
228231
Y_new = Y
229232
X_new = X - 1 if X > 0 else X
230233

231-
elif chosen_action == 3: # RIGHT
234+
elif chosen_action == 3: # RIGHT
232235
Y_new = Y
233-
X_new = X +1 if X < agent.border else X
236+
X_new = X + 1 if X < agent.border else X
234237

235-
elif chosen_action == 4: # STAY
238+
elif chosen_action == 4: # STAY
236239
Y_new, X_new = Y, X
237240

238241
return (X_new, Y_new)
@@ -244,39 +247,39 @@ def move_3d(self, agent, chosen_action):
244247
Z_new = Z
245248
# here
246249

247-
if chosen_action == 0: # UP
248-
250+
if chosen_action == 0: # UP
251+
249252
Y_new = Y - 1 if Y > 0 else Y
250253
X_new = X
251254
Z_new = Z
252255

253-
elif chosen_action == 1: # DOWN
256+
elif chosen_action == 1: # DOWN
254257

255258
Y_new = Y + 1 if Y < agent.border else Y
256259
X_new = X
257260
Z_new = Z
258261

259-
elif chosen_action == 2: # LEFT
262+
elif chosen_action == 2: # LEFT
260263
Y_new = Y
261264
X_new = X - 1 if X > 0 else X
262265
Z_new = Z
263266

264-
elif chosen_action == 3: # RIGHT
267+
elif chosen_action == 3: # RIGHT
265268
Y_new = Y
266-
X_new = X +1 if X < agent.border else X
269+
X_new = X + 1 if X < agent.border else X
267270
Z_new = Z
268271

269-
elif chosen_action == 4: # IN
272+
elif chosen_action == 4: # IN
270273
X_new = X
271274
Y_new = Y
272275
Z_new = Z + 1 if Z < agent.border else Z
273276

274-
elif chosen_action == 5: # OUT
277+
elif chosen_action == 5: # OUT
275278
X_new = X
276279
Y_new = Y
277-
Z_new = Z -1 if Z > agent.border else Z
280+
Z_new = Z - 1 if Z > agent.border else Z
278281

279-
elif chosen_action == 6: # STAY
282+
elif chosen_action == 6: # STAY
280283
Y_new, X_new, Z_new = Y, X, Z
281284

282285
return (X_new, Y_new, Z_new)

0 commit comments

Comments
 (0)