Skip to content

Commit 8df7ead

Browse files
authored
Merge pull request #2 from ActiveInferenceLab/develop
Develop
2 parents 5ae194e + fc43711 commit 8df7ead

File tree

6 files changed

+1322
-761
lines changed

6 files changed

+1322
-761
lines changed

blockference/agent.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
""" Agent Class
5+
__author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein
6+
"""
7+
8+
import warnings
9+
import numpy as np
10+
from pymdp import inference, control, learning
11+
from pymdp import utils, maths
12+
from pymdp.agent import Agent
13+
import copy
14+
15+
16+
17+
class Agent(Agent):
    """
    Blockference wrapper around ``pymdp.agent.Agent``.

    Accepts the same constructor parameters as the pymdp base class and
    forwards them unchanged.

    NOTE(review): this subclass shadows the imported pymdp ``Agent`` name;
    consider renaming it (e.g. ``BlockAgent``) in a follow-up so both
    classes remain addressable.
    """

    def __init__(
        self,
        A,
        B,
        C=None,
        D=None,
        E=None,
        pA=None,
        pB=None,
        pD=None,
        num_controls=None,
        policy_len=1,
        inference_horizon=1,
        control_fac_idx=None,
        policies=None,
        gamma=16.0,
        use_utility=True,
        use_states_info_gain=True,
        use_param_info_gain=False,
        action_selection="deterministic",
        inference_algo="VANILLA",
        inference_params=None,
        modalities_to_learn="all",
        lr_pA=1.0,
        factors_to_learn="all",
        lr_pB=1.0,
        lr_pD=1.0,
        use_BMA=True,
        policy_sep_prior=False,
        save_belief_hist=False
    ):
        # Bug fix: the original called ``super().__init__()`` with no
        # arguments, which discarded every constructor parameter and would
        # raise TypeError (pymdp's Agent requires at least A and B).
        # Forward everything to the base class by keyword; the parameter
        # names here were copied from pymdp's own signature, so they match.
        super().__init__(
            A=A,
            B=B,
            C=C,
            D=D,
            E=E,
            pA=pA,
            pB=pB,
            pD=pD,
            num_controls=num_controls,
            policy_len=policy_len,
            inference_horizon=inference_horizon,
            control_fac_idx=control_fac_idx,
            policies=policies,
            gamma=gamma,
            use_utility=use_utility,
            use_states_info_gain=use_states_info_gain,
            use_param_info_gain=use_param_info_gain,
            action_selection=action_selection,
            inference_algo=inference_algo,
            inference_params=inference_params,
            modalities_to_learn=modalities_to_learn,
            lr_pA=lr_pA,
            factors_to_learn=factors_to_learn,
            lr_pB=lr_pB,
            lr_pD=lr_pD,
            use_BMA=use_BMA,
            policy_sep_prior=policy_sep_prior,
            save_belief_hist=save_belief_hist,
        )

blockference/envs/grid_env.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from blockference.gridference import *
2+
3+
4+
class GridAgent():
    """Multi-agent wrapper for a discrete n-dimensional grid world.

    Builds the grid of coordinate tuples, spawns ``num_agents``
    ``ActiveGridference`` agents on it, and provides movement helpers
    plus a single active-inference step over all agents.
    """

    def __init__(self, grid_len, num_agents, grid_dim=2) -> None:
        # List of coordinate tuples, each padded with a trailing 0 flag
        # by get_grid().
        self.grid = self.get_grid(grid_len, grid_dim)
        self.grid_dim = grid_dim
        # One STAY action plus a decrement and an increment per dimension.
        self.no_actions = 2 * grid_dim + 1
        self.agents = self.init_agents(num_agents)

    def get_grid(self, grid_len, grid_dim):
        """Return every grid coordinate as a tuple padded with a 0.

        E.g. grid_len=2, grid_dim=2 -> [(0, 0, 0), (0, 1, 0), (1, 0, 0),
        (1, 1, 0)].
        """
        g = list(itertools.product(range(grid_len), repeat=grid_dim))
        for i, p in enumerate(g):
            # Tuples are immutable: += rebinds the list slot to p + (0,).
            g[i] += (0,)
        return g

    def move_grid(self, agent, chosen_action):
        """Apply ``chosen_action`` to ``agent.env_state`` in n dimensions.

        Action 0 is STAY; odd actions decrement one coordinate and even
        actions increment one, clamped to [0, agent.border]. Returns the
        new state as a list.
        """
        state = list(agent.env_state)
        new_state = state.copy()

        if chosen_action == 0:  # STAY
            new_state = state
        else:
            if chosen_action % 2 == 1:
                # Bug fix: '/' produces a float, which raises TypeError
                # when used as a list index; use integer division.
                # NOTE(review): this mapping sends actions 1,2 to index 1
                # and 3,4 to index 2, never touching index 0 — confirm
                # this action-to-dimension layout is intended.
                index = (chosen_action + 1) // 2
                new_state[index] = state[index] - 1 if state[index] > 0 else state[index]
            elif chosen_action % 2 == 0:
                index = chosen_action // 2
                new_state[index] = state[index] + 1 if state[index] < agent.border else state[index]
        return new_state

    def init_agents(self, no_agents):
        """Create ``no_agents`` ActiveGridference agents keyed by index.

        Each agent gets a random target (C) and a random start (D).

        NOTE(review): coordinates are drawn from randint(0, 9) regardless
        of the actual grid length — confirm this matches the grid size.
        """
        agents = {}

        for a in range(no_agents):
            # Create a new agent on the shared grid.
            agent = ActiveGridference(self.grid)
            # Random target state (trailing 0 flag appended).
            target = (rand.randint(0, 9), rand.randint(0, 9))
            agent.get_C(target + (0,))
            # Random start state (trailing 1 flag appended).
            start = (rand.randint(0, 9), rand.randint(0, 9))
            agent.get_D(start + (1,))

            agents[a] = agent

        return agents

    def actinf_dict(self, agents_dict, g_agent):
        """Run one active-inference step for every agent in ``agents_dict``.

        Returns {'agent_updates': [...]} where each entry carries the
        agent's new prior, env state, chosen action and posterior.
        """
        # List of all updates to the agents in the network.
        agent_updates = []

        for source, agent in agents_dict.items():

            policies = construct_policies([agent.n_states], [len(agent.E)], policy_len=agent.policy_len)
            # Observation index = position of the agent's state in the grid.
            obs_idx = g_agent.grid.index(agent.env_state)

            # Infer hidden states from the observation.
            qs_current = u.infer_states(obs_idx, agent.A, agent.prior)

            # Expected free energy of each policy.
            _G = u.calculate_G_policies(agent.A, agent.B, agent.C, qs_current, policies=policies)

            # Posterior over policies.
            Q_pi = u.softmax(-_G)
            # Probability of each action under the policy posterior.
            P_u = u.compute_prob_actions(agent.E, policies, Q_pi)

            # Sample an action.
            chosen_action = u.sample(P_u)

            # Prior for the next step, pushed through the transition model.
            prior = agent.B[:, :, chosen_action].dot(qs_current)

            # Update env state with the new grid location.
            # NOTE(review): env_state tuples from the grid carry a trailing
            # flag element, but move_2d unpacks exactly two coordinates —
            # confirm the expected env_state arity.
            current_state = self.move_2d(agent, chosen_action)
            agent_update = {'source': source,
                            'update_prior': prior,
                            'update_env': current_state,
                            'update_action': chosen_action,
                            'update_inference': qs_current}
            agent_updates.append(agent_update)

        return {'agent_updates': agent_updates}

    def move_2d(self, agent, chosen_action):
        """Move a 2-D agent one step (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT, 4=STAY).

        Coordinates are clamped to [0, agent.border].
        """
        (Y, X) = agent.env_state
        Y_new = Y
        X_new = X

        if chosen_action == 0:  # UP
            Y_new = Y - 1 if Y > 0 else Y
            X_new = X

        elif chosen_action == 1:  # DOWN
            Y_new = Y + 1 if Y < agent.border else Y
            X_new = X

        elif chosen_action == 2:  # LEFT
            Y_new = Y
            X_new = X - 1 if X > 0 else X

        elif chosen_action == 3:  # RIGHT
            Y_new = Y
            X_new = X + 1 if X < agent.border else X

        elif chosen_action == 4:  # STAY
            Y_new, X_new = Y, X

        # Bug fix: the original returned (X_new, Y_new), swapping the
        # coordinate order relative to the (Y, X) unpacking above, so the
        # axes flipped every time the result was stored back as env_state.
        return (Y_new, X_new)

    def move_3d(self, agent, chosen_action):
        """Move a 3-D agent one step (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT,
        4=IN, 5=OUT, 6=STAY).

        Coordinates are clamped to [0, agent.border].
        """
        (Y, X, Z) = agent.env_state
        Y_new = Y
        X_new = X
        Z_new = Z

        if chosen_action == 0:  # UP
            Y_new = Y - 1 if Y > 0 else Y
            X_new = X
            Z_new = Z

        elif chosen_action == 1:  # DOWN
            Y_new = Y + 1 if Y < agent.border else Y
            X_new = X
            Z_new = Z

        elif chosen_action == 2:  # LEFT
            Y_new = Y
            X_new = X - 1 if X > 0 else X
            Z_new = Z

        elif chosen_action == 3:  # RIGHT
            Y_new = Y
            X_new = X + 1 if X < agent.border else X
            Z_new = Z

        elif chosen_action == 4:  # IN
            X_new = X
            Y_new = Y
            Z_new = Z + 1 if Z < agent.border else Z

        elif chosen_action == 5:  # OUT
            X_new = X
            Y_new = Y
            # Bug fix: the original guarded the decrement with
            # ``Z > agent.border`` (so OUT only worked above the border);
            # every other decrement branch clamps at 0, so this one
            # should too.
            Z_new = Z - 1 if Z > 0 else Z

        elif chosen_action == 6:  # STAY
            Y_new, X_new, Z_new = Y, X, Z

        # Bug fix: return in the same (Y, X, Z) order the state was
        # unpacked in, matching the move_2d fix.
        return (Y_new, X_new, Z_new)

0 commit comments

Comments
 (0)