11import numpy as np
22import tools .utils as utils
3+ from blockference .gridference import GridAgent
34
4- class ActiveGridference ():
5+ class ActiveGridference (GridAgent ):
56 """
67 The ActiveInference class is to be used to create a generative model to be used in cadCAD simulations.
78 The current focus is on discrete spaces.
@@ -14,7 +15,8 @@ class ActiveGridference():
1415 - (initial state) D -> the generative model's prior belief over hidden states at the first timestep
1516 - (affordances) E -> the generative model's available actions
1617 """
17- def __init__ (self , grid , planning_length : int = 2 , env_state : tuple = (0 , 0 ), ) -> None :
18+ def __init__ (self , planning_length : int = 2 , env_state : tuple = (0 , 0 ), ) -> None :
19+ super ().__init__ ()
1820 self .A = None
1921 self .B = None
2022 self .C = None
@@ -24,7 +26,6 @@ def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0), )
2426 self .policy_len = planning_length
2527
2628 # environment
27- self .grid = grid
2829 self .n_states = len (self .grid )
2930 self .n_observations = len (self .grid )
3031 self .border = np .sqrt (self .n_states ) - 1
@@ -56,7 +57,7 @@ def get_B(self):
5657
5758 for curr_state , grid_location in enumerate (self .grid ):
5859
59- y , x = grid_location
60+ y , x , z = grid_location
6061
6162 if action_label == "UP" :
6263 next_y = y - 1 if y > 0 else y
@@ -73,7 +74,7 @@ def get_B(self):
7374 elif action_label == "STAY" :
7475 next_x = x
7576 next_y = y
76- new_location = (next_y , next_x )
77+ new_location = (next_y , next_x , 0 )
7778 next_state = self .grid .index (new_location )
7879 self .B [next_state , curr_state , action_id ] = 1.0
7980
0 commit comments