Skip to content

Commit cce0eee

Browse files
authored
Merge pull request #3 from ActiveInferenceLab/develop
Develop
2 parents b99a3be + 2c29200 commit cce0eee

19 files changed

+1000
-1930
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ __pycache__*
33
*.xml
44
*.iml
55
*.ipynb_checkpoints/
6+
./blockference/.ipynb_checkpoints
7+
.DS_Store
8+
**checkpoint.py

blockference/.ipynb_checkpoints/gridference-checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0)) ->
345345

346346
def get_A(self):
347347
"""
348-
State Matrix (identity matrix)
348+
State Matrix (identity matrix for the single agent gridworld)
349349
Params:
350350
- n_observations: int: number of possible observations
351351
- n_states: int: number of possible states

blockference/agent.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,33 @@ def __init__(
4848
lr_pD=1.0,
4949
use_BMA=True,
5050
policy_sep_prior=False,
51-
save_belief_hist=False
51+
save_belief_hist=False,
5252
):
53-
super().__init__()
53+
super().__init__(A,
54+
B,
55+
C=None,
56+
D=None,
57+
E=None,
58+
pA=None,
59+
pB=None,
60+
pD=None,
61+
num_controls=None,
62+
policy_len=1,
63+
inference_horizon=1,
64+
control_fac_idx=None,
65+
policies=None,
66+
gamma=16.0,
67+
use_utility=True,
68+
use_states_info_gain=True,
69+
use_param_info_gain=False,
70+
action_selection="deterministic",
71+
inference_algo="VANILLA",
72+
inference_params=None,
73+
modalities_to_learn="all",
74+
lr_pA=1.0,
75+
factors_to_learn="all",
76+
lr_pB=1.0,
77+
lr_pD=1.0,
78+
use_BMA=True,
79+
policy_sep_prior=False,
80+
save_belief_hist=False,)

blockference/envs/grid_env.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ def __init__(self, grid_len, num_agents, grid_dim=2) -> None:
66
self.grid = self.get_grid(grid_len, grid_dim)
77
self.grid_dim = grid_dim
88
self.no_actions = 2 * grid_dim + 1
9-
self.agents = self.init_agents(num_agents)
9+
self.n_observations = grid_len ** 2
10+
self.n_states = grid_len ** 2
11+
self.border = np.sqrt(self.n_states) - 1
12+
# self.agents = self.init_agents(num_agents)
1013

1114
def get_grid(self, grid_len, grid_dim):
1215
g = list(itertools.product(range(grid_len), repeat=grid_dim))
13-
for i, p in enumerate(g):
14-
g[i] += (0,)
1516
return g
1617

1718
def move_grid(self, agent, chosen_action):
@@ -29,7 +30,7 @@ def move_grid(self, agent, chosen_action):
2930
new_state[index] = state[index] - 1 if state[index] > 0 else state[index]
3031
elif chosen_action % 2 == 0:
3132
index = chosen_action / 2
32-
new_state[index] = state[index] + 1 if state[index] < agent.border else state[index]
33+
new_state[index] = state[index] + 1 if state[index] < self.border else state[index]
3334
return new_state
3435

3536
def init_agents(self, no_agents):

blockference/tools/.ipynb_checkpoints/__init__-checkpoint.py

Whitespace-only changes.

blockference/tools/.ipynb_checkpoints/policy-checkpoint.py

Lines changed: 0 additions & 136 deletions
This file was deleted.

0 commit comments

Comments
 (0)