@@ -3346,6 +3346,10 @@ def test_batched_dynamic(self, break_when_any_done):
33463346 )
33473347 del env_no_buffers
33483348 gc .collect ()
3349+ # print(dummy_rollouts)
3350+ # print(rollout_no_buffers_serial)
3351+ # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3352+ # assert_allclose_td(a, b)
33493353 assert_allclose_td (
33503354 dummy_rollouts .exclude ("action" ),
33513355 rollout_no_buffers_serial .exclude ("action" ),
@@ -3441,35 +3445,146 @@ def test_partial_rest(self, batched):
34413445
34423446# fen strings for board positions generated with:
34433447# https://lichess.org/editor
3444- @pytest .mark .parametrize ("stateful" , [False , True ])
34453448@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
34463449class TestChessEnv :
3447- def test_env (self , stateful ):
3448- env = ChessEnv (stateful = stateful )
3449- check_env_specs (env )
3450+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3451+ @pytest .mark .parametrize ("include_fen" , [False , True ])
3452+ @pytest .mark .parametrize ("stateful" , [False , True ])
3453+ @pytest .mark .parametrize ("include_hash" , [False , True ])
3454+ @pytest .mark .parametrize ("include_san" , [False , True ])
3455+ def test_env (self , stateful , include_pgn , include_fen , include_hash , include_san ):
3456+ with pytest .raises (
3457+ RuntimeError , match = "At least one state representation"
3458+ ) if not stateful and not include_pgn and not include_fen else contextlib .nullcontext ():
3459+ env = ChessEnv (
3460+ stateful = stateful ,
3461+ include_pgn = include_pgn ,
3462+ include_fen = include_fen ,
3463+ include_hash = include_hash ,
3464+ include_san = include_san ,
3465+ )
3466+ check_env_specs (env )
3467+ if include_hash :
3468+ if include_fen :
3469+ assert "fen_hash" in env .observation_spec .keys ()
3470+ if include_pgn :
3471+ assert "pgn_hash" in env .observation_spec .keys ()
3472+ if include_san :
3473+ assert "san_hash" in env .observation_spec .keys ()
3474+
3475+ def test_pgn_bijectivity (self ):
3476+ np .random .seed (0 )
3477+ pgn = ChessEnv ._PGN_RESTART
3478+ board = ChessEnv ._pgn_to_board (pgn )
3479+ pgn_prev = pgn
3480+ for _ in range (10 ):
3481+ moves = list (board .legal_moves )
3482+ move = np .random .choice (moves )
3483+ board .push (move )
3484+ pgn_move = ChessEnv ._board_to_pgn (board )
3485+ assert pgn_move != pgn_prev
3486+ assert pgn_move == ChessEnv ._board_to_pgn (ChessEnv ._pgn_to_board (pgn_move ))
3487+ assert pgn_move == ChessEnv ._add_move_to_pgn (pgn_prev , move )
3488+ pgn_prev = pgn_move
3489+
3490+ def test_consistency (self ):
3491+ env0_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = True )
3492+ env1_stateful = ChessEnv (stateful = True , include_pgn = False , include_fen = True )
3493+ env2_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = False )
3494+ env0_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = True )
3495+ env1_stateless = ChessEnv (stateful = False , include_pgn = False , include_fen = True )
3496+ env2_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = False )
3497+ torch .manual_seed (0 )
3498+ r1_stateless = env1_stateless .rollout (50 , break_when_any_done = False )
3499+ torch .manual_seed (0 )
3500+ r1_stateful = env1_stateful .rollout (50 , break_when_any_done = False )
3501+ torch .manual_seed (0 )
3502+ r2_stateless = env2_stateless .rollout (50 , break_when_any_done = False )
3503+ torch .manual_seed (0 )
3504+ r2_stateful = env2_stateful .rollout (50 , break_when_any_done = False )
3505+ torch .manual_seed (0 )
3506+ r0_stateless = env0_stateless .rollout (50 , break_when_any_done = False )
3507+ torch .manual_seed (0 )
3508+ r0_stateful = env0_stateful .rollout (50 , break_when_any_done = False )
3509+ assert (r0_stateless ["action" ] == r1_stateless ["action" ]).all ()
3510+ assert (r0_stateless ["action" ] == r2_stateless ["action" ]).all ()
3511+ assert (r0_stateless ["action" ] == r0_stateful ["action" ]).all ()
3512+ assert (r1_stateless ["action" ] == r1_stateful ["action" ]).all ()
3513+ assert (r2_stateless ["action" ] == r2_stateful ["action" ]).all ()
3514+
3515+ @pytest .mark .parametrize (
3516+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3517+ )
3518+ @pytest .mark .parametrize ("stateful" , [False , True ])
3519+ def test_san (self , stateful , include_fen , include_pgn ):
3520+ torch .manual_seed (0 )
3521+ env = ChessEnv (
3522+ stateful = stateful ,
3523+ include_pgn = include_pgn ,
3524+ include_fen = include_fen ,
3525+ include_san = True ,
3526+ )
3527+ r = env .rollout (100 , break_when_any_done = False )
3528+ sans = r ["next" , "san" ]
3529+ actions = [env .san_moves .index (san ) for san in sans ]
3530+ i = 0
3531+
3532+ def policy (td ):
3533+ nonlocal i
3534+ td ["action" ] = actions [i ]
3535+ i += 1
3536+ return td
34503537
3451- def test_rollout (self , stateful ):
3452- env = ChessEnv (stateful = stateful )
3453- env .rollout (5000 )
3538+ r2 = env .rollout (100 , policy = policy , break_when_any_done = False )
3539+ assert_allclose_td (r , r2 )
34543540
3455- def test_reset_white_to_move (self , stateful ):
3456- env = ChessEnv (stateful = stateful )
3541+ @pytest .mark .parametrize (
3542+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3543+ )
3544+ @pytest .mark .parametrize ("stateful" , [False , True ])
3545+ def test_rollout (self , stateful , include_pgn , include_fen ):
3546+ torch .manual_seed (0 )
3547+ env = ChessEnv (
3548+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3549+ )
3550+ r = env .rollout (500 , break_when_any_done = False )
3551+ assert r .shape == (500 ,)
3552+
3553+ @pytest .mark .parametrize (
3554+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3555+ )
3556+ @pytest .mark .parametrize ("stateful" , [False , True ])
3557+ def test_reset_white_to_move (self , stateful , include_pgn , include_fen ):
3558+ env = ChessEnv (
3559+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3560+ )
34573561 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34583562 td = env .reset (TensorDict ({"fen" : fen }))
34593563 assert td ["fen" ] == fen
3564+ if include_fen :
3565+ assert env .board .fen () == fen
34603566 assert td ["turn" ] == env .lib .WHITE
34613567 assert not td ["done" ]
34623568
3463- def test_reset_black_to_move (self , stateful ):
3464- env = ChessEnv (stateful = stateful )
3569+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3570+ @pytest .mark .parametrize ("stateful" , [False , True ])
3571+ def test_reset_black_to_move (self , stateful , include_pgn , include_fen ):
3572+ env = ChessEnv (
3573+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3574+ )
34653575 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34663576 td = env .reset (TensorDict ({"fen" : fen }))
34673577 assert td ["fen" ] == fen
3578+ assert env .board .fen () == fen
34683579 assert td ["turn" ] == env .lib .BLACK
34693580 assert not td ["done" ]
34703581
3471- def test_reset_done_error (self , stateful ):
3472- env = ChessEnv (stateful = stateful )
3582+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3583+ @pytest .mark .parametrize ("stateful" , [False , True ])
3584+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3585+ env = ChessEnv (
3586+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3587+ )
34733588 fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34743589 with pytest .raises (ValueError ) as e_info :
34753590 env .reset (TensorDict ({"fen" : fen }))
@@ -3480,12 +3595,19 @@ def test_reset_done_error(self, stateful):
34803595 @pytest .mark .parametrize (
34813596 "endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
34823597 )
3483- def test_reward (self , stateful , reset_without_fen , endstate ):
3598+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3599+ @pytest .mark .parametrize ("include_fen" , [True ])
3600+ @pytest .mark .parametrize ("stateful" , [False , True ])
3601+ def test_reward (
3602+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3603+ ):
34843604 if stateful and reset_without_fen :
34853605 # reset_without_fen is only used for stateless env
34863606 return
34873607
3488- env = ChessEnv (stateful = stateful )
3608+ env = ChessEnv (
3609+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3610+ )
34893611
34903612 if endstate == "white win" :
34913613 fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3498,28 +3620,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34983620 fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34993621 expected_turn = env .lib .BLACK
35003622 move = "Rg1#"
3501- expected_reward = - 1
3623+ expected_reward = 1
35023624 expected_done = True
35033625
35043626 elif endstate == "stalemate" :
35053627 fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
35063628 expected_turn = env .lib .BLACK
35073629 move = "Rb7"
3508- expected_reward = 0
3630+ expected_reward = 0.5
35093631 expected_done = True
35103632
35113633 elif endstate == "insufficient" :
35123634 fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
35133635 expected_turn = env .lib .WHITE
35143636 move = "Kxd4"
3515- expected_reward = 0
3637+ expected_reward = 0.5
35163638 expected_done = True
35173639
35183640 elif endstate == "50 move" :
35193641 fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
35203642 expected_turn = env .lib .BLACK
35213643 move = "Kf7"
3522- expected_reward = 0
3644+ expected_reward = 0.5
35233645 expected_done = True
35243646
35253647 elif endstate == "not_done" :
@@ -3538,8 +3660,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
35383660 td = env .reset (TensorDict ({"fen" : fen }))
35393661 assert td ["turn" ] == expected_turn
35403662
3541- moves = env .get_legal_moves (None if stateful else td )
3542- td ["action" ] = moves .index (move )
3663+ td ["action" ] = env ._san_moves .index (move )
35433664 td = env .step (td )["next" ]
35443665 assert td ["done" ] == expected_done
35453666 assert td ["reward" ] == expected_reward
0 commit comments