@@ -4234,43 +4234,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
42344234 td_check = env .reset (td .select ("fen_hash" ))
42354235 assert (td_check == td ).all ()
42364236
4237- @pytest .mark .parametrize ("include_fen" , [False , True ])
4238- @pytest .mark .parametrize ("include_pgn" , [False , True ])
4237+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[False , True ], [True , False ]])
42394238 @pytest .mark .parametrize ("stateful" , [False , True ])
4240- @pytest .mark .parametrize ("mask_actions" , [False , True ])
4241- def test_all_actions (self , include_fen , include_pgn , stateful , mask_actions ):
4242- if not stateful and not include_fen and not include_pgn :
4243- # pytest.skip("fen or pgn must be included if not stateful")
4244- return
4245-
4239+ @pytest .mark .parametrize ("include_hash" , [False , True ])
4240+ @pytest .mark .parametrize ("include_san" , [False , True ])
4241+ @pytest .mark .parametrize ("append_transform" , [False , True ])
4242+ @pytest .mark .parametrize ("mask_actions" , [True ])
4243+ def test_all_actions (
4244+ self ,
4245+ include_fen ,
4246+ include_pgn ,
4247+ stateful ,
4248+ include_hash ,
4249+ include_san ,
4250+ append_transform ,
4251+ mask_actions ,
4252+ ):
42464253 env = ChessEnv (
42474254 include_fen = include_fen ,
42484255 include_pgn = include_pgn ,
4256+ include_san = include_san ,
4257+ include_hash = include_hash ,
4258+ include_hash_inv = include_hash ,
42494259 stateful = stateful ,
42504260 mask_actions = mask_actions ,
42514261 )
4252- td = env .reset ()
42534262
4254- if not mask_actions :
4255- with pytest .raises (RuntimeError , match = "Cannot generate legal actions" ):
4256- env .all_actions ()
4257- return
4263+ def transform_reward (td ):
4264+ if "reward" not in td :
4265+ return td
4266+ reward = td ["reward" ]
4267+ if reward == 0.5 :
4268+ td ["reward" ] = 0
4269+ elif reward == 1 and td ["turn" ]:
4270+ td ["reward" ] = - td ["reward" ]
4271+ return td
4272+
4273+ if append_transform :
4274+ env = env .append_transform (transform_reward )
4275+
4276+ check_env_specs (env )
4277+
4278+ td = env .reset ()
42584279
42594280 # Choose random actions from the output of `all_actions`
4260- for _ in range (100 ):
4261- if stateful :
4262- all_actions = env .all_actions ()
4263- else :
4281+ for step_idx in range (100 ):
4282+ if step_idx % 5 == 0 :
42644283 # Reset the initial state first, just to make sure
42654284 # `all_actions` knows how to get the board state from the input.
42664285 env .reset ()
4267- all_actions = env .all_actions (td .clone ())
4286+ all_actions = env .all_actions (td .clone ())
42684287
42694288 # Choose some random actions and make sure they match exactly one of
42704289 # the actions from `all_actions`. This part is not tested when
42714290 # `mask_actions == False`, because `rand_action` can pick illegal
42724291 # actions in that case.
4273- if mask_actions :
4292+ if mask_actions and step_idx % 4 == 0 :
42744293 # TODO: Something is wrong in `ChessEnv.rand_action` which makes
42754294 # it fail to work properly for stateless mode. It doesn't know
42764295 # how to correctly reset the board state to what is given in the
@@ -4287,7 +4306,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42874306
42884307 action_idx = torch .randint (0 , all_actions .shape [0 ], ()).item ()
42894308 chosen_action = all_actions [action_idx ]
4290- td = env .step (td .update (chosen_action ))["next" ]
4309+ td_new = env .step (td .update (chosen_action ).clone ())
4310+ assert (td == td_new .exclude ("next" )).all ()
4311+ td = td_new ["next" ]
42914312
42924313 if td ["done" ]:
42934314 td = env .reset ()
0 commit comments