Skip to content

Commit a48577c

Browse files
committed
WIP: update model
1 parent 4775e71 commit a48577c

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tools/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
import tools.utils as utils
3+
from blockference.gridference import GridAgent
34

4-
class ActiveGridference():
5+
class ActiveGridference(GridAgent):
56
"""
67
The ActiveInference class is to be used to create a generative model to be used in cadCAD simulations.
78
The current focus is on discrete spaces.
@@ -14,7 +15,8 @@ class ActiveGridference():
1415
- (initial state) D -> the generative model's prior belief over hidden states at the first timestep
1516
- (affordances) E -> the generative model's available actions
1617
"""
17-
def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0), ) -> None:
18+
def __init__(self, planning_length: int = 2, env_state: tuple = (0, 0), ) -> None:
19+
super().__init__()
1820
self.A = None
1921
self.B = None
2022
self.C = None
@@ -24,7 +26,6 @@ def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0), )
2426
self.policy_len = planning_length
2527

2628
# environment
27-
self.grid = grid
2829
self.n_states = len(self.grid)
2930
self.n_observations = len(self.grid)
3031
self.border = np.sqrt(self.n_states) - 1
@@ -56,7 +57,7 @@ def get_B(self):
5657

5758
for curr_state, grid_location in enumerate(self.grid):
5859

59-
y, x = grid_location
60+
y, x, z = grid_location
6061

6162
if action_label == "UP":
6263
next_y = y - 1 if y > 0 else y
@@ -73,7 +74,7 @@ def get_B(self):
7374
elif action_label == "STAY":
7475
next_x = x
7576
next_y = y
76-
new_location = (next_y, next_x)
77+
new_location = (next_y, next_x, 0)
7778
next_state = self.grid.index(new_location)
7879
self.B[next_state, curr_state, action_id] = 1.0
7980

0 commit comments

Comments
 (0)