131131from torchrl .envs .transforms .transforms import (
132132 AutoResetEnv ,
133133 AutoResetTransform ,
134+ Tokenizer ,
134135 Transform ,
135136)
136137from torchrl .envs .utils import (
@@ -3441,35 +3442,148 @@ def test_partial_rest(self, batched):
34413442
34423443# fen strings for board positions generated with:
34433444# https://lichess.org/editor
3444- @pytest .mark .parametrize ("stateful" , [False , True ])
34453445@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
34463446class TestChessEnv :
3447- def test_env (self , stateful ):
3448- env = ChessEnv (stateful = stateful )
3449- check_env_specs (env )
3447+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3448+ @pytest .mark .parametrize ("include_fen" , [False , True ])
3449+ @pytest .mark .parametrize ("stateful" , [False , True ])
3450+ @pytest .mark .parametrize ("include_hash" , [False , True ])
3451+ @pytest .mark .parametrize ("include_san" , [False , True ])
3452+ def test_env (self , stateful , include_pgn , include_fen , include_hash , include_san ):
3453+ with pytest .raises (
3454+ RuntimeError , match = "At least one state representation"
3455+ ) if not stateful and not include_pgn and not include_fen else contextlib .nullcontext ():
3456+ env = ChessEnv (
3457+ stateful = stateful ,
3458+ include_pgn = include_pgn ,
3459+ include_fen = include_fen ,
3460+ include_hash = include_hash ,
3461+ include_san = include_san ,
3462+ )
3463+ # Because we always use mask_actions=True
3464+ assert isinstance (env , TransformedEnv )
3465+ check_env_specs (env )
3466+ if include_hash :
3467+ if include_fen :
3468+ assert "fen_hash" in env .observation_spec .keys ()
3469+ if include_pgn :
3470+ assert "pgn_hash" in env .observation_spec .keys ()
3471+ if include_san :
3472+ assert "san_hash" in env .observation_spec .keys ()
3473+
3474+ def test_pgn_bijectivity (self ):
3475+ np .random .seed (0 )
3476+ pgn = ChessEnv ._PGN_RESTART
3477+ board = ChessEnv ._pgn_to_board (pgn )
3478+ pgn_prev = pgn
3479+ for _ in range (10 ):
3480+ moves = list (board .legal_moves )
3481+ move = np .random .choice (moves )
3482+ board .push (move )
3483+ pgn_move = ChessEnv ._board_to_pgn (board )
3484+ assert pgn_move != pgn_prev
3485+ assert pgn_move == ChessEnv ._board_to_pgn (ChessEnv ._pgn_to_board (pgn_move ))
3486+ assert pgn_move == ChessEnv ._add_move_to_pgn (pgn_prev , move )
3487+ pgn_prev = pgn_move
3488+
3489+ def test_consistency (self ):
3490+ env0_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = True )
3491+ env1_stateful = ChessEnv (stateful = True , include_pgn = False , include_fen = True )
3492+ env2_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = False )
3493+ env0_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = True )
3494+ env1_stateless = ChessEnv (stateful = False , include_pgn = False , include_fen = True )
3495+ env2_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = False )
3496+ torch .manual_seed (0 )
3497+ r1_stateless = env1_stateless .rollout (50 , break_when_any_done = False )
3498+ torch .manual_seed (0 )
3499+ r1_stateful = env1_stateful .rollout (50 , break_when_any_done = False )
3500+ torch .manual_seed (0 )
3501+ r2_stateless = env2_stateless .rollout (50 , break_when_any_done = False )
3502+ torch .manual_seed (0 )
3503+ r2_stateful = env2_stateful .rollout (50 , break_when_any_done = False )
3504+ torch .manual_seed (0 )
3505+ r0_stateless = env0_stateless .rollout (50 , break_when_any_done = False )
3506+ torch .manual_seed (0 )
3507+ r0_stateful = env0_stateful .rollout (50 , break_when_any_done = False )
3508+ assert (r0_stateless ["action" ] == r1_stateless ["action" ]).all ()
3509+ assert (r0_stateless ["action" ] == r2_stateless ["action" ]).all ()
3510+ assert (r0_stateless ["action" ] == r0_stateful ["action" ]).all ()
3511+ assert (r1_stateless ["action" ] == r1_stateful ["action" ]).all ()
3512+ assert (r2_stateless ["action" ] == r2_stateful ["action" ]).all ()
3513+
3514+ @pytest .mark .parametrize (
3515+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3516+ )
3517+ @pytest .mark .parametrize ("stateful" , [False , True ])
3518+ def test_san (self , stateful , include_fen , include_pgn ):
3519+ torch .manual_seed (0 )
3520+ env = ChessEnv (
3521+ stateful = stateful ,
3522+ include_pgn = include_pgn ,
3523+ include_fen = include_fen ,
3524+ include_san = True ,
3525+ )
3526+ r = env .rollout (100 , break_when_any_done = False )
3527+ sans = r ["next" , "san" ]
3528+ actions = [env .san_moves .index (san ) for san in sans ]
3529+ i = 0
3530+
3531+ def policy (td ):
3532+ nonlocal i
3533+ td ["action" ] = actions [i ]
3534+ i += 1
3535+ return td
34503536
3451- def test_rollout (self , stateful ):
3452- env = ChessEnv (stateful = stateful )
3453- env .rollout (5000 )
3537+ r2 = env .rollout (100 , policy = policy , break_when_any_done = False )
3538+ assert_allclose_td (r , r2 )
34543539
3455- def test_reset_white_to_move (self , stateful ):
3456- env = ChessEnv (stateful = stateful )
3540+ @pytest .mark .parametrize (
3541+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3542+ )
3543+ @pytest .mark .parametrize ("stateful" , [False , True ])
3544+ def test_rollout (self , stateful , include_pgn , include_fen ):
3545+ torch .manual_seed (0 )
3546+ env = ChessEnv (
3547+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3548+ )
3549+ r = env .rollout (500 , break_when_any_done = False )
3550+ assert r .shape == (500 ,)
3551+
3552+ @pytest .mark .parametrize (
3553+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3554+ )
3555+ @pytest .mark .parametrize ("stateful" , [False , True ])
3556+ def test_reset_white_to_move (self , stateful , include_pgn , include_fen ):
3557+ env = ChessEnv (
3558+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3559+ )
34573560 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34583561 td = env .reset (TensorDict ({"fen" : fen }))
3459- assert td ["fen" ] == fen
3562+ if include_fen :
3563+ assert td ["fen" ] == fen
3564+ assert env .board .fen () == fen
34603565 assert td ["turn" ] == env .lib .WHITE
34613566 assert not td ["done" ]
34623567
3463- def test_reset_black_to_move (self , stateful ):
3464- env = ChessEnv (stateful = stateful )
3568+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3569+ @pytest .mark .parametrize ("stateful" , [False , True ])
3570+ def test_reset_black_to_move (self , stateful , include_pgn , include_fen ):
3571+ env = ChessEnv (
3572+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3573+ )
34653574 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34663575 td = env .reset (TensorDict ({"fen" : fen }))
34673576 assert td ["fen" ] == fen
3577+ assert env .board .fen () == fen
34683578 assert td ["turn" ] == env .lib .BLACK
34693579 assert not td ["done" ]
34703580
3471- def test_reset_done_error (self , stateful ):
3472- env = ChessEnv (stateful = stateful )
3581+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3582+ @pytest .mark .parametrize ("stateful" , [False , True ])
3583+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3584+ env = ChessEnv (
3585+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3586+ )
34733587 fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34743588 with pytest .raises (ValueError ) as e_info :
34753589 env .reset (TensorDict ({"fen" : fen }))
@@ -3480,12 +3594,19 @@ def test_reset_done_error(self, stateful):
34803594 @pytest .mark .parametrize (
34813595 "endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
34823596 )
3483- def test_reward (self , stateful , reset_without_fen , endstate ):
3597+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3598+ @pytest .mark .parametrize ("include_fen" , [True ])
3599+ @pytest .mark .parametrize ("stateful" , [False , True ])
3600+ def test_reward (
3601+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3602+ ):
34843603 if stateful and reset_without_fen :
34853604 # reset_without_fen is only used for stateless env
34863605 return
34873606
3488- env = ChessEnv (stateful = stateful )
3607+ env = ChessEnv (
3608+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3609+ )
34893610
34903611 if endstate == "white win" :
34913612 fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3619,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34983619 fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34993620 expected_turn = env .lib .BLACK
35003621 move = "Rg1#"
3501- expected_reward = - 1
3622+ expected_reward = 1
35023623 expected_done = True
35033624
35043625 elif endstate == "stalemate" :
35053626 fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
35063627 expected_turn = env .lib .BLACK
35073628 move = "Rb7"
3508- expected_reward = 0
3629+ expected_reward = 0.5
35093630 expected_done = True
35103631
35113632 elif endstate == "insufficient" :
35123633 fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
35133634 expected_turn = env .lib .WHITE
35143635 move = "Kxd4"
3515- expected_reward = 0
3636+ expected_reward = 0.5
35163637 expected_done = True
35173638
35183639 elif endstate == "50 move" :
35193640 fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
35203641 expected_turn = env .lib .BLACK
35213642 move = "Kf7"
3522- expected_reward = 0
3643+ expected_reward = 0.5
35233644 expected_done = True
35243645
35253646 elif endstate == "not_done" :
@@ -3538,13 +3659,33 @@ def test_reward(self, stateful, reset_without_fen, endstate):
35383659 td = env .reset (TensorDict ({"fen" : fen }))
35393660 assert td ["turn" ] == expected_turn
35403661
3541- moves = env .get_legal_moves (None if stateful else td )
3542- td ["action" ] = moves .index (move )
3662+ td ["action" ] = env ._san_moves .index (move )
35433663 td = env .step (td )["next" ]
35443664 assert td ["done" ] == expected_done
35453665 assert td ["reward" ] == expected_reward
35463666 assert td ["turn" ] == (not expected_turn )
35473667
3668+ def test_chess_tokenized (self ):
3669+ env = ChessEnv (include_fen = True , stateful = True , include_san = True )
3670+ assert isinstance (env .observation_spec ["fen" ], NonTensor )
3671+ env = env .append_transform (
3672+ Tokenizer (in_keys = ["fen" ], out_keys = ["fen_tokenized" ])
3673+ )
3674+ assert isinstance (env .observation_spec ["fen" ], NonTensor )
3675+ env .transform .transform_output_spec (env .base_env .output_spec )
3676+ env .transform .transform_input_spec (env .base_env .input_spec )
3677+ r = env .rollout (10 , return_contiguous = False )
3678+ assert "fen_tokenized" in r
3679+ assert "fen" in r
3680+ assert "fen_tokenized" in r ["next" ]
3681+ assert "fen" in r ["next" ]
3682+ ftd = env .fake_tensordict ()
3683+ assert "fen_tokenized" in ftd
3684+ assert "fen" in ftd
3685+ assert "fen_tokenized" in ftd ["next" ]
3686+ assert "fen" in ftd ["next" ]
3687+ env .check_env_specs ()
3688+
35483689
35493690class TestCustomEnvs :
35503691 def test_tictactoe_env (self ):
0 commit comments