@@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
7676 being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
7777
7878 Examples:
79+ >>> import torch
80+ >>> from torchrl.envs import ChessEnv
81+ >>> _ = torch.manual_seed(0)
7982 >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
83+ >>> print(env)
84+ TransformedEnv(
85+ env=ChessEnv(),
86+ transform=ActionMask(keys=['action', 'action_mask']))
8087 >>> r = env.reset()
81- >>> env.rand_step(r)
88+ >>> print( env.rand_step(r) )
8289 TensorDict(
8390 fields={
8491 action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
92+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
8593 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
8694 fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
8795 legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
8896 next: TensorDict(
8997 fields={
98+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
9099 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
91- fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
100+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
92101 legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
93102 pgn: NonTensorData(data=[Event "?"]
94103 [Site "?"]
@@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
97106 [White "?"]
98107 [Black "?"]
99108 [Result "*"]
100- 1. b3 *, batch_size=torch.Size([]), device=None),
109+
110+ 1. f4 *, batch_size=torch.Size([]), device=None),
101111 reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
102- san: NonTensorData(data=b3 , batch_size=torch.Size([]), device=None),
112+ san: NonTensorData(data=f4 , batch_size=torch.Size([]), device=None),
103113 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
104114 turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
105115 batch_size=torch.Size([]),
@@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
112122 [White "?"]
113123 [Black "?"]
114124 [Result "*"]
125+
115126 *, batch_size=torch.Size([]), device=None),
116- san: NonTensorData(data=[SAN][START] , batch_size=torch.Size([]), device=None),
127+ san: NonTensorData(data=<start> , batch_size=torch.Size([]), device=None),
117128 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
118129 turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
119130 batch_size=torch.Size([]),
120131 device=None,
121132 is_shared=False)
122- >>> env.rollout(1000)
133+ >>> print( env.rollout(1000) )
123134 TensorDict(
124135 fields={
125- action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
126- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
136+ action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
137+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
138+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127139 fen: NonTensorStack(
128140 ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
129- batch_size=torch.Size([352 ]),
141+ batch_size=torch.Size([96 ]),
130142 device=None),
131- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
143+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
132144 next: TensorDict(
133145 fields={
134- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
146+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
147+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
135148 fen: NonTensorStack(
136- ['rnbqkbnr/pppppppp/8/8/8/N7 /PPPPPPPP/R1BQKBNR b K ...,
137- batch_size=torch.Size([352 ]),
149+ ['rnbqkbnr/pppppppp/8/8/8/5N2 /PPPPPPPP/RNBQKB1R b ...,
150+ batch_size=torch.Size([96 ]),
138151 device=None),
139- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
152+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
140153 pgn: NonTensorStack(
141154 ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
142- batch_size=torch.Size([352 ]),
155+ batch_size=torch.Size([96 ]),
143156 device=None),
144- reward: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
157+ reward: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
145158 san: NonTensorStack(
146- ['Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', 'g6 ', 'd4 ', 'd6' ...,
147- batch_size=torch.Size([352 ]),
159+ ['Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8 ', 'Na3 ', 'Ra ...,
160+ batch_size=torch.Size([96 ]),
148161 device=None),
149- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
150- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
151- batch_size=torch.Size([352 ]),
162+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164+ batch_size=torch.Size([96 ]),
152165 device=None,
153166 is_shared=False),
154167 pgn: NonTensorStack(
155168 ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
156- batch_size=torch.Size([352 ]),
169+ batch_size=torch.Size([96 ]),
157170 device=None),
158171 san: NonTensorStack(
159- ['[SAN][START] ', 'Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', ...,
160- batch_size=torch.Size([352 ]),
172+ ['<start> ', 'Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8', ...,
173+ batch_size=torch.Size([96 ]),
161174 device=None),
162- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164- batch_size=torch.Size([352 ]),
175+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
176+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
177+ batch_size=torch.Size([96 ]),
165178 device=None,
166179 is_shared=False)
167180 """ # noqa: D301
@@ -225,13 +238,15 @@ def _legal_moves_to_index(
225238 [self ._san_moves .index (board .san (m )) for m in board .legal_moves ],
226239 dtype = torch .int64 ,
227240 )
228-
241+ mask = None
229242 if return_mask :
230- return self ._move_index_to_mask (indices )
243+ mask = self ._move_index_to_mask (indices )
231244 if pad :
232245 indices = torch .nn .functional .pad (
233246 indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
234247 )
248+ if return_mask :
249+ return indices , mask
235250 return indices
236251
237252 @classmethod
@@ -369,16 +384,19 @@ def _reset(self, tensordict=None):
369384 dest .set ("pgn" , pgn )
370385 dest .set ("turn" , turn )
371386 if self .include_legal_moves :
372- moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
373- dest .set ("legal_moves" , moves_idx )
387+ moves_idx = self ._legal_moves_to_index (
388+ board = self .board , pad = True , return_mask = self .mask_actions
389+ )
374390 if self .mask_actions :
375- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
391+ moves_idx , mask = moves_idx
392+ dest .set ("action_mask" , mask )
393+ dest .set ("legal_moves" , moves_idx )
376394 elif self .mask_actions :
377395 dest .set (
378396 "action_mask" ,
379397 self ._legal_moves_to_index (
380398 board = self .board , pad = True , return_mask = True
381- ),
399+ )[ 1 ] ,
382400 )
383401
384402 if self .pixels :
@@ -525,16 +543,19 @@ def _step(self, tensordict):
525543 dest .set ("san" , san )
526544
527545 if self .include_legal_moves :
528- moves_idx = self ._legal_moves_to_index (board = board , pad = True )
529- dest .set ("legal_moves" , moves_idx )
546+ moves_idx = self ._legal_moves_to_index (
547+ board = board , pad = True , return_mask = self .mask_actions
548+ )
530549 if self .mask_actions :
531- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
550+ moves_idx , mask = moves_idx
551+ dest .set ("action_mask" , mask )
552+ dest .set ("legal_moves" , moves_idx )
532553 elif self .mask_actions :
533554 dest .set (
534555 "action_mask" ,
535556 self ._legal_moves_to_index (
536557 board = self .board , pad = True , return_mask = True
537- ),
558+ )[ 1 ] ,
538559 )
539560
540561 turn = torch .tensor (board .turn )
0 commit comments