@@ -4058,6 +4058,35 @@ def test_chess_tokenized(self):
40584058 assert "fen" in ftd ["next" ]
40594059 env .check_env_specs ()
40604060
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
def test_env_reset_with_hash(self, stateful, include_san):
    """Check that a reset keyed by ``fen_hash`` matches a reset keyed by ``fen``.

    For each position the environment is first reset from the FEN string,
    then reset with no input, and finally reset from the ``fen_hash`` entry
    alone. The last reset must yield a tensordict equal to the first one.

    Args:
        stateful: forwarded to ``ChessEnv(stateful=...)``.
        include_san: forwarded to ``ChessEnv(include_san=...)``.
    """
    # Maps each FEN string to the number of legal moves expected for it.
    expected_moves_by_fen = {
        "5R1k/8/8/8/6R1/8/8/5K2 b - - 0 1": 1,
        "8/8/2kq4/4K3/1R3Q2/8/8/8 w - - 0 1": 2,
        "6R1/8/8/4rq2/3pPk2/5n2/8/2B1R2K b - e3 0 1": 2,
    }
    chess_env = ChessEnv(
        include_fen=True,
        include_hash=True,
        include_hash_inv=True,
        stateful=stateful,
        include_san=include_san,
    )

    for position_fen, expected_moves in expected_moves_by_fen.items():
        # First pass: restore the position from its FEN string.
        from_fen = chess_env.reset(TensorDict({"fen": position_fen}))
        assert from_fen["fen"] == position_fen
        legal_move_count = from_fen["action_mask"].sum()
        assert legal_move_count == expected_moves

        # Move the env away from that position (an input-free reset is
        # expected to expose 20 legal moves) so that the hash-based reset
        # below cannot pass merely because the state was left unchanged.
        default_reset = chess_env.reset()
        assert default_reset["action_mask"].sum() == 20

        # Second pass: restore from the hash only and compare the complete
        # output against the FEN-based reset.
        hash_only_input = from_fen.select("fen_hash")
        from_hash = chess_env.reset(hash_only_input)
        assert (from_hash == from_fen).all()
4089+
40614090
40624091class TestCustomEnvs :
40634092 def test_tictactoe_env (self ):
0 commit comments