77import importlib .util
88import io
99import pathlib
10- from typing import Dict , Optional
10+ from typing import Dict
1111
1212import torch
1313from PIL import Image
1414from tensordict import TensorDict , TensorDictBase
15- from torchrl .data import Bounded , Categorical , Composite , NonTensor , Unbounded
15+ from torchrl .data import Binary , Bounded , Categorical , Composite , NonTensor , Unbounded
1616
1717from torchrl .envs import EnvBase
1818from torchrl .envs .common import _EnvPostInit
1919
2020from torchrl .envs .utils import _classproperty
2121
2222
23- class _HashMeta (_EnvPostInit ):
23+ class _ChessMeta (_EnvPostInit ):
2424 def __call__ (cls , * args , ** kwargs ):
2525 instance = super ().__call__ (* args , ** kwargs )
2626 if kwargs .get ("include_hash" ):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
3737 if instance .include_pgn :
3838 in_keys .append ("pgn" )
3939 out_keys .append ("pgn_hash" )
40- return instance .append_transform (Hash (in_keys , out_keys ))
40+ instance = instance .append_transform (Hash (in_keys , out_keys ))
41+ if kwargs .get ("mask_actions" , True ):
42+ from torchrl .envs import ActionMask
43+
44+ instance = instance .append_transform (ActionMask ())
4145 return instance
4246
4347
44- class ChessEnv (EnvBase , metaclass = _HashMeta ):
48+ class ChessEnv (EnvBase , metaclass = _ChessMeta ):
4549 r"""A chess environment that follows the TorchRL API.
4650
4751 This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
6367 include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
6468 include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
6569 include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70+ mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
71+ to the env to make sure that the actions are properly masked. Default: ``True``.
6672 pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
6773
6874 .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -200,16 +206,15 @@ def _legal_moves_to_index(
200206 ) -> torch .Tensor :
201207 if not self .stateful :
202208 if tensordict is None :
203- raise RuntimeError (
204- "rand_action requires a tensordict when stateful is False."
205- )
206- if self .include_fen :
207- fen = self ._get_fen (tensordict )
209+ # trust the board
210+ pass
211+ elif self .include_fen :
212+ fen = tensordict .get ("fen" , None )
208213 fen = fen .data
209214 self .board .set_fen (fen )
210215 board = self .board
211216 elif self .include_pgn :
212- pgn = self . _get_pgn ( tensordict )
217+ pgn = tensordict . get ( "pgn" )
213218 pgn = pgn .data
214219 board = self ._pgn_to_board (pgn , self .board )
215220
@@ -222,15 +227,19 @@ def _legal_moves_to_index(
222227 )
223228
224229 if return_mask :
225- return torch .zeros (len (self .san_moves ), dtype = torch .bool ).index_fill_ (
226- 0 , indices , True
227- )
230+ return self ._move_index_to_mask (indices )
228231 if pad :
229232 indices = torch .nn .functional .pad (
230233 indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
231234 )
232235 return indices
233236
@classmethod
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
    """Convert SAN-move indices into a boolean action mask.

    Args:
        indices (torch.Tensor): integer tensor of indices into ``cls.san_moves``.
            Entries greater than or equal to ``len(cls.san_moves)`` are ignored,
            so the padded output of ``_legal_moves_to_index(..., pad=True)``
            (which uses ``len(cls.san_moves)`` as its fill value) can be passed
            directly.

    Returns:
        torch.Tensor: a 1D ``torch.bool`` tensor with ``len(cls.san_moves)``
        entries, ``True`` at every in-range index found in ``indices``.
    """
    n_moves = len(cls.san_moves)
    # The pad value equals ``n_moves``, which is one past the last valid slot:
    # ``index_fill_`` would raise on it, so padding entries are dropped first.
    in_range = indices[indices < n_moves]
    return torch.zeros(n_moves, dtype=torch.bool).index_fill_(0, in_range, True)
242+
234243 def __init__ (
235244 self ,
236245 * ,
@@ -240,6 +249,7 @@ def __init__(
240249 include_pgn : bool = False ,
241250 include_legal_moves : bool = False ,
242251 include_hash : bool = False ,
252+ mask_actions : bool = True ,
243253 pixels : bool = False ,
244254 ):
245255 chess = self .lib
@@ -250,6 +260,7 @@ def __init__(
250260 self .include_san = include_san
251261 self .include_fen = include_fen
252262 self .include_pgn = include_pgn
263+ self .mask_actions = mask_actions
253264 self .include_legal_moves = include_legal_moves
254265 if include_legal_moves :
255266 # 218 max possible legal moves per chess board position
@@ -274,8 +285,10 @@ def __init__(
274285
275286 self .stateful = stateful
276287
277- if not self .stateful :
278- self .full_state_spec = self .full_observation_spec .clone ()
288+ # state_spec is loosely defined as such - it's not really an issue that extra keys
289+ # can go missing but it allows us to reset the env using fen passed to the reset
290+ # method.
291+ self .full_state_spec = self .full_observation_spec .clone ()
279292
280293 self .pixels = pixels
281294 if pixels :
@@ -295,16 +308,16 @@ def __init__(
295308 self .full_reward_spec = Composite (
296309 reward = Unbounded (shape = (1 ,), dtype = torch .float32 )
297310 )
311+ if self .mask_actions :
312+ self .full_observation_spec ["action_mask" ] = Binary (
313+ n = len (self .san_moves ), dtype = torch .bool
314+ )
315+
298316 # done spec generated automatically
299317 self .board = chess .Board ()
300318 if self .stateful :
301319 self .action_spec .set_provisional_n (len (list (self .board .legal_moves )))
302320
303- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
304- mask = self ._legal_moves_to_index (tensordict , return_mask = True )
305- self .action_spec .update_mask (mask )
306- return super ().rand_action (tensordict )
307-
308321 def _is_done (self , board ):
309322 return board .is_game_over () | board .is_fifty_moves ()
310323
@@ -314,11 +327,11 @@ def _reset(self, tensordict=None):
314327 if tensordict is not None :
315328 dest = tensordict .empty ()
316329 if self .include_fen :
317- fen = self . _get_fen ( tensordict )
330+ fen = tensordict . get ( "fen" , None )
318331 if fen is not None :
319332 fen = fen .data
320333 elif self .include_pgn :
321- pgn = self . _get_pgn ( tensordict )
334+ pgn = tensordict . get ( "pgn" , None )
322335 if pgn is not None :
323336 pgn = pgn .data
324337 else :
@@ -358,13 +371,18 @@ def _reset(self, tensordict=None):
358371 if self .include_legal_moves :
359372 moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
360373 dest .set ("legal_moves" , moves_idx )
374+ if self .mask_actions :
375+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
376+ elif self .mask_actions :
377+ dest .set (
378+ "action_mask" ,
379+ self ._legal_moves_to_index (
380+ board = self .board , pad = True , return_mask = True
381+ ),
382+ )
383+
361384 if self .pixels :
362385 dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
363-
364- if self .stateful :
365- mask = self ._legal_moves_to_index (dest , return_mask = True )
366- self .action_spec .update_mask (mask )
367-
368386 return dest
369387
370388 _cairosvg_lib = None
@@ -435,16 +453,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
435453 pgn_string = str (game )
436454 return pgn_string
437455
438- @classmethod
439- def _get_fen (cls , tensordict ):
440- fen = tensordict .get ("fen" , None )
441- return fen
442-
443- @classmethod
444- def _get_pgn (cls , tensordict ):
445- pgn = tensordict .get ("pgn" , None )
446- return pgn
447-
448456 def get_legal_moves (self , tensordict = None , uci = False ):
449457 """List the legal moves in a position.
450458
@@ -468,7 +476,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
468476 raise ValueError (
469477 "tensordict must be given since this env is not stateful"
470478 )
471- fen = self . _get_fen ( tensordict ).data
479+ fen = tensordict . get ( "fen" ).data
472480 board .set_fen (fen )
473481 moves = board .legal_moves
474482
@@ -486,10 +494,10 @@ def _step(self, tensordict):
486494 fen = None
487495 if not self .stateful :
488496 if self .include_fen :
489- fen = self . _get_fen ( tensordict ).data
497+ fen = tensordict . get ( "fen" ).data
490498 board .set_fen (fen )
491499 elif self .include_pgn :
492- pgn = self . _get_pgn ( tensordict ).data
500+ pgn = tensordict . get ( "pgn" ).data
493501 board = self ._pgn_to_board (pgn , board )
494502 else :
495503 raise RuntimeError (
@@ -519,6 +527,15 @@ def _step(self, tensordict):
519527 if self .include_legal_moves :
520528 moves_idx = self ._legal_moves_to_index (board = board , pad = True )
521529 dest .set ("legal_moves" , moves_idx )
530+ if self .mask_actions :
531+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
532+ elif self .mask_actions :
533+ dest .set (
534+ "action_mask" ,
535+ self ._legal_moves_to_index (
536+ board = self .board , pad = True , return_mask = True
537+ ),
538+ )
522539
523540 turn = torch .tensor (board .turn )
524541 done = self ._is_done (board )
@@ -538,11 +555,6 @@ def _step(self, tensordict):
538555 dest .set ("terminated" , [done ])
539556 if self .pixels :
540557 dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
541-
542- if self .stateful :
543- mask = self ._legal_moves_to_index (dest , return_mask = True )
544- self .action_spec .update_mask (mask )
545-
546558 return dest
547559
548560 def _set_seed (self , * args , ** kwargs ):
0 commit comments