@@ -3291,6 +3291,10 @@ def test_batched_dynamic(self, break_when_any_done):
32913291 )
32923292 del env_no_buffers
32933293 gc .collect ()
3294+ # print(dummy_rollouts)
3295+ # print(rollout_no_buffers_serial)
3296+ # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3297+ # assert_allclose_td(a, b)
32943298 assert_allclose_td (
32953299 dummy_rollouts .exclude ("action" ),
32963300 rollout_no_buffers_serial .exclude ("action" ),
@@ -3386,35 +3390,107 @@ def test_partial_rest(self, batched):
33863390
33873391# fen strings for board positions generated with:
33883392# https://lichess.org/editor
3389- @pytest .mark .parametrize ("stateful" , [False , True ])
33903393@pytest .mark .skipif (not _has_chess , reason = "chess not found" )
33913394class TestChessEnv :
3392- def test_env (self , stateful ):
3393- env = ChessEnv (stateful = stateful )
3394- check_env_specs (env )
3395+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3396+ @pytest .mark .parametrize ("include_fen" , [False , True ])
3397+ @pytest .mark .parametrize ("stateful" , [False , True ])
3398+ def test_env (self , stateful , include_pgn , include_fen ):
3399+ with pytest .raises (
3400+ RuntimeError , match = "At least one state representation"
3401+ ) if not stateful and not include_pgn and not include_fen else contextlib .nullcontext ():
3402+ env = ChessEnv (
3403+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3404+ )
3405+ check_env_specs (env )
33953406
3396- def test_rollout (self , stateful ):
3397- env = ChessEnv (stateful = stateful )
3398- env .rollout (5000 )
3407+ def test_pgn_bijectivity (self ):
3408+ np .random .seed (0 )
3409+ pgn = ChessEnv ._PGN_RESTART
3410+ board = ChessEnv ._pgn_to_board (pgn )
3411+ pgn_prev = pgn
3412+ for _ in range (10 ):
3413+ moves = list (board .legal_moves )
3414+ move = np .random .choice (moves )
3415+ board .push (move )
3416+ pgn_move = ChessEnv ._board_to_pgn (board )
3417+ assert pgn_move != pgn_prev
3418+ assert pgn_move == ChessEnv ._board_to_pgn (ChessEnv ._pgn_to_board (pgn_move ))
3419+ assert pgn_move == ChessEnv ._add_move_to_pgn (pgn_prev , move )
3420+ pgn_prev = pgn_move
3421+
3422+ def test_consistency (self ):
3423+ env0_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = True )
3424+ env1_stateful = ChessEnv (stateful = True , include_pgn = False , include_fen = True )
3425+ env2_stateful = ChessEnv (stateful = True , include_pgn = True , include_fen = False )
3426+ env0_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = True )
3427+ env1_stateless = ChessEnv (stateful = False , include_pgn = False , include_fen = True )
3428+ env2_stateless = ChessEnv (stateful = False , include_pgn = True , include_fen = False )
3429+ torch .manual_seed (0 )
3430+ r1_stateless = env1_stateless .rollout (50 , break_when_any_done = False )
3431+ torch .manual_seed (0 )
3432+ r1_stateful = env1_stateful .rollout (50 , break_when_any_done = False )
3433+ torch .manual_seed (0 )
3434+ r2_stateless = env2_stateless .rollout (50 , break_when_any_done = False )
3435+ torch .manual_seed (0 )
3436+ r2_stateful = env2_stateful .rollout (50 , break_when_any_done = False )
3437+ torch .manual_seed (0 )
3438+ r0_stateless = env0_stateless .rollout (50 , break_when_any_done = False )
3439+ torch .manual_seed (0 )
3440+ r0_stateful = env0_stateful .rollout (50 , break_when_any_done = False )
3441+ assert (r0_stateless ["action" ] == r1_stateless ["action" ]).all ()
3442+ assert (r0_stateless ["action" ] == r2_stateless ["action" ]).all ()
3443+ assert (r0_stateless ["action" ] == r0_stateful ["action" ]).all ()
3444+ assert (r1_stateless ["action" ] == r1_stateful ["action" ]).all ()
3445+ assert (r2_stateless ["action" ] == r2_stateful ["action" ]).all ()
3446+
3447+ @pytest .mark .parametrize (
3448+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3449+ )
3450+ @pytest .mark .parametrize ("stateful" , [False , True ])
3451+ def test_rollout (self , stateful , include_pgn , include_fen ):
3452+ torch .manual_seed (0 )
3453+ env = ChessEnv (
3454+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3455+ )
3456+ r = env .rollout (500 , break_when_any_done = False )
3457+ assert r .shape == (500 ,)
33993458
3400- def test_reset_white_to_move (self , stateful ):
3401- env = ChessEnv (stateful = stateful )
3459+ @pytest .mark .parametrize (
3460+ "include_fen,include_pgn" , [[True , False ], [False , True ], [True , True ]]
3461+ )
3462+ @pytest .mark .parametrize ("stateful" , [False , True ])
3463+ def test_reset_white_to_move (self , stateful , include_pgn , include_fen ):
3464+ env = ChessEnv (
3465+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3466+ )
34023467 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
34033468 td = env .reset (TensorDict ({"fen" : fen }))
34043469 assert td ["fen" ] == fen
3470+ if include_fen :
3471+ assert env .board .fen () == fen
34053472 assert td ["turn" ] == env .lib .WHITE
34063473 assert not td ["done" ]
34073474
3408- def test_reset_black_to_move (self , stateful ):
3409- env = ChessEnv (stateful = stateful )
3475+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3476+ @pytest .mark .parametrize ("stateful" , [False , True ])
3477+ def test_reset_black_to_move (self , stateful , include_pgn , include_fen ):
3478+ env = ChessEnv (
3479+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3480+ )
34103481 fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
34113482 td = env .reset (TensorDict ({"fen" : fen }))
34123483 assert td ["fen" ] == fen
3484+ assert env .board .fen () == fen
34133485 assert td ["turn" ] == env .lib .BLACK
34143486 assert not td ["done" ]
34153487
3416- def test_reset_done_error (self , stateful ):
3417- env = ChessEnv (stateful = stateful )
3488+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[True , False ], [True , True ]])
3489+ @pytest .mark .parametrize ("stateful" , [False , True ])
3490+ def test_reset_done_error (self , stateful , include_pgn , include_fen ):
3491+ env = ChessEnv (
3492+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3493+ )
34183494 fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
34193495 with pytest .raises (ValueError ) as e_info :
34203496 env .reset (TensorDict ({"fen" : fen }))
@@ -3425,12 +3501,19 @@ def test_reset_done_error(self, stateful):
34253501 @pytest .mark .parametrize (
34263502 "endstate" , ["white win" , "black win" , "stalemate" , "50 move" , "insufficient" ]
34273503 )
3428- def test_reward (self , stateful , reset_without_fen , endstate ):
3504+ @pytest .mark .parametrize ("include_pgn" , [False , True ])
3505+ @pytest .mark .parametrize ("include_fen" , [True ])
3506+ @pytest .mark .parametrize ("stateful" , [False , True ])
3507+ def test_reward (
3508+ self , stateful , reset_without_fen , endstate , include_pgn , include_fen
3509+ ):
34293510 if stateful and reset_without_fen :
34303511 # reset_without_fen is only used for stateless env
34313512 return
34323513
3433- env = ChessEnv (stateful = stateful )
3514+ env = ChessEnv (
3515+ stateful = stateful , include_pgn = include_pgn , include_fen = include_fen
3516+ )
34343517
34353518 if endstate == "white win" :
34363519 fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
@@ -3443,28 +3526,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34433526 fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
34443527 expected_turn = env .lib .BLACK
34453528 move = "Rg1#"
3446- expected_reward = - 1
3529+ expected_reward = 1
34473530 expected_done = True
34483531
34493532 elif endstate == "stalemate" :
34503533 fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
34513534 expected_turn = env .lib .BLACK
34523535 move = "Rb7"
3453- expected_reward = 0
3536+ expected_reward = 0.5
34543537 expected_done = True
34553538
34563539 elif endstate == "insufficient" :
34573540 fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
34583541 expected_turn = env .lib .WHITE
34593542 move = "Kxd4"
3460- expected_reward = 0
3543+ expected_reward = 0.5
34613544 expected_done = True
34623545
34633546 elif endstate == "50 move" :
34643547 fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
34653548 expected_turn = env .lib .BLACK
34663549 move = "Kf7"
3467- expected_reward = 0
3550+ expected_reward = 0.5
34683551 expected_done = True
34693552
34703553 elif endstate == "not_done" :
@@ -3483,8 +3566,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
34833566 td = env .reset (TensorDict ({"fen" : fen }))
34843567 assert td ["turn" ] == expected_turn
34853568
3486- moves = env .get_legal_moves (None if stateful else td )
3487- td ["action" ] = moves .index (move )
3569+ td ["action" ] = env ._san_moves .index (move )
34883570 td = env .step (td )["next" ]
34893571 assert td ["done" ] == expected_done
34903572 assert td ["reward" ] == expected_reward
0 commit comments