44# LICENSE file in the root directory of this source tree.
55from __future__ import annotations
66
7+ import importlib .util
8+ import io
79from typing import Dict , Optional
810
911import torch
12+ from PIL import Image
1013from tensordict import TensorDict , TensorDictBase
1114from torchrl .data import Categorical , Composite , NonTensor , Unbounded
1215
1316from torchrl .envs import EnvBase
17+ from torchrl .envs .common import _EnvPostInit
1418
1519from torchrl .envs .utils import _classproperty
1620
1721
18- class ChessEnv (EnvBase ):
22+ class _HashMeta (_EnvPostInit ):
23+ def __call__ (cls , * args , ** kwargs ):
24+ instance = super ().__call__ (* args , ** kwargs )
25+ if kwargs .get ("include_hash" ):
26+ from torchrl .envs import Hash
27+
28+ in_keys = []
29+ out_keys = []
30+ if instance .include_san :
31+ in_keys .append ("san" )
32+ out_keys .append ("san_hash" )
33+ if instance .include_fen :
34+ in_keys .append ("fen" )
35+ out_keys .append ("fen_hash" )
36+ if instance .include_pgn :
37+ in_keys .append ("pgn" )
38+ out_keys .append ("pgn_hash" )
39+ return instance .append_transform (Hash (in_keys , out_keys ))
40+ return instance
41+
42+
43+ class ChessEnv (EnvBase , metaclass = _HashMeta ):
1944 """A chess environment that follows the TorchRL API.
2045
2146 Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
2247
2348 Args:
2449 stateful (bool): Whether to keep track of the internal state of the board.
2550 If False, the state will be stored in the observation and passed back
26- to the environment on each call. Default: ``False ``.
51+ to the environment on each call. Default: ``True ``.
2752
2853 .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
2954 Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
@@ -90,28 +115,76 @@ class ChessEnv(EnvBase):
90115 """
91116
92117 _hash_table : Dict [int , str ] = {}
118+ _PNG_RESTART = """[Event "?"]
119+ [Site "?"]
120+ [Date "????.??.??"]
121+ [Round "?"]
122+ [White "?"]
123+ [Black "?"]
124+ [Result "*"]
125+
126+ *"""
93127
94128 @_classproperty
95129 def lib (cls ):
96130 try :
97131 import chess
132+ import chess .pgn
98133 except ImportError :
99134 raise ImportError (
100135 "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
101136 )
102137 return chess
103138
104- def __init__ (self , stateful : bool = False ):
139+ def __init__ (
140+ self ,
141+ * ,
142+ stateful : bool = True ,
143+ include_san : bool = False ,
144+ include_fen : bool = False ,
145+ include_pgn : bool = False ,
146+ include_hash : bool = False ,
147+ pixels : bool = False ,
148+ ):
105149 chess = self .lib
106150 super ().__init__ ()
107151 self .full_observation_spec = Composite (
108- hashing = Unbounded (shape = (), dtype = torch .int64 ),
109- fen = NonTensor (shape = ()),
110152 turn = Categorical (n = 2 , dtype = torch .bool , shape = ()),
111153 )
154+ self .include_san = include_san
155+ self .include_fen = include_fen
156+ self .include_pgn = include_pgn
157+ if include_san :
158+ self .full_observation_spec ["san" ] = NonTensor (shape = (), example_data = "Nc6" )
159+ if include_pgn :
160+ self .full_observation_spec ["pgn" ] = NonTensor (
161+ shape = (), example_data = self ._PNG_RESTART
162+ )
163+ if include_fen :
164+ self .full_observation_spec ["fen" ] = NonTensor (shape = (), example_data = "any" )
165+ if not stateful and not (include_pgn or include_fen ):
166+ raise RuntimeError (
167+ "At least one state representation (pgn or fen) must be enabled when stateful "
168+ f"is { stateful } ."
169+ )
170+
112171 self .stateful = stateful
172+
113173 if not self .stateful :
114174 self .full_state_spec = self .full_observation_spec .clone ()
175+
176+ self .pixels = pixels
177+ if pixels :
178+ if importlib .util .find_spec ("cairosvg" ) is None :
179+ raise ImportError (
180+ "Please install cairosvg to use this environment with pixel rendering."
181+ )
182+ if importlib .util .find_spec ("torchvision" ) is None :
183+ raise ImportError (
184+ "Please install torchvision to use this environment with pixel rendering."
185+ )
186+ self .full_observation_spec ["pixels" ] = Unbounded (shape = ())
187+
115188 self .full_action_spec = Composite (
116189 action = Categorical (n = - 1 , shape = (), dtype = torch .int64 )
117190 )
@@ -132,41 +205,126 @@ def _is_done(self, board):
132205
133206 def _reset (self , tensordict = None ):
134207 fen = None
208+ pgn = None
135209 if tensordict is not None :
136- fen = self ._get_fen (tensordict ).data
137- dest = tensordict .empty ()
210+ if self .include_fen :
211+ fen = self ._get_fen (tensordict ).data
212+ dest = tensordict .empty ()
213+ if self .include_pgn :
214+ fen = self ._get_pgn (tensordict ).data
215+ dest = tensordict .empty ()
138216 else :
139217 dest = TensorDict ()
140218
141- if fen is None :
219+ if fen is None and pgn is None :
142220 self .board .reset ()
143- fen = self .board .fen ()
221+ if self .include_fen and fen is None :
222+ fen = self .board .fen ()
223+ if self .include_pgn and pgn is None :
224+ pgn = self ._PNG_RESTART
144225 else :
145- self .board .set_fen (fen )
146- if self ._is_done (self .board ):
147- raise ValueError (
148- "Cannot reset to a fen that is a gameover state." f" fen: { fen } "
149- )
150-
151- hashing = hash (fen )
226+ if fen is not None :
227+ self .board .set_fen (fen )
228+ if self ._is_done (self .board ):
229+ raise ValueError (
230+ "Cannot reset to a fen that is a gameover state." f" fen: { fen } "
231+ )
232+ elif pgn is not None :
233+ self .board = self ._pgn_to_board (pgn )
152234
153235 self ._set_action_space ()
154236 turn = self .board .turn
155- return dest .set ("fen" , fen ).set ("hashing" , hashing ).set ("turn" , turn )
237+ if self .include_san :
238+ dest .set ("san" , "[SAN][START]" )
239+ if self .include_fen :
240+ if fen is None :
241+ fen = self .board .fen ()
242+ dest .set ("fen" , fen )
243+ if self .include_pgn :
244+ if pgn is None :
245+ pgn = self ._board_to_pgn (self .board )
246+ dest .set ("pgn" , pgn )
247+ dest .set ("turn" , turn )
248+ if self .pixels :
249+ dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
250+ return dest
251+
252+ _cairosvg_lib = None
253+
254+ @_classproperty
255+ def _cairosvg (cls ):
256+ csvg = cls ._cairosvg_lib
257+ if csvg is None :
258+ import cairosvg
259+
260+ csvg = cls ._cairosvg_lib = cairosvg
261+ return csvg
262+
263+ _torchvision_lib = None
264+
265+ @_classproperty
266+ def _torchvision (cls ):
267+ tv = cls ._torchvision_lib
268+ if tv is None :
269+ import torchvision
270+
271+ tv = cls ._torchvision_lib = torchvision
272+ return tv
273+
274+ @classmethod
275+ def _get_tensor_image (cls , board ):
276+ try :
277+ svg = board ._repr_svg_ ()
278+ # Convert SVG to PNG using cairosvg
279+ png_data = io .BytesIO ()
280+ cls ._cairosvg .svg2png (bytestring = svg .encode ("utf-8" ), write_to = png_data )
281+ png_data .seek (0 )
282+ # Open the PNG image using Pillow
283+ img = Image .open (png_data )
284+ img = cls ._torchvision .transforms .functional .pil_to_tensor (img )
285+ except ImportError :
286+ raise ImportError (
287+ "Chess rendering requires cairosvg and torchvision to be installed."
288+ )
289+ return img
156290
157291 def _set_action_space (self , tensordict : TensorDict | None = None ):
158292 if not self .stateful and tensordict is not None :
159293 fen = self ._get_fen (tensordict ).data
160294 self .board .set_fen (fen )
161295 self .action_spec .set_provisional_n (self .board .legal_moves .count ())
162296
297+ @classmethod
298+ def _pgn_to_board (
299+ cls , pgn_string : str , board : "chess.Board" | None = None
300+ ) -> "chess.Board" :
301+ pgn_io = io .StringIO (pgn_string )
302+ game = cls .lib .pgn .read_game (pgn_io )
303+ if board is None :
304+ board = cls .Board ()
305+ else :
306+ board .reset ()
307+ for move in game .mainline_moves ():
308+ board .push (move )
309+ return board
310+
311+ @classmethod
312+ def _board_to_pgn (cls , board : "chess.Board" ) -> str :
313+ # Create a new Game object
314+ game = cls .lib .pgn .Game ()
315+
316+ # Add the moves to the game
317+ node = game
318+ for move in board .move_stack :
319+ node = node .add_variation (move )
320+
321+ # Generate the PGN string
322+ pgn_string = str (game )
323+ return pgn_string
324+
163325 @classmethod
164326 def _get_fen (cls , tensordict ):
165327 fen = tensordict .get ("fen" , None )
166- if fen is None :
167- hashing = tensordict .get ("hashing" , None )
168- if hashing is not None :
169- fen = cls ._hash_table .get (hashing .item ())
170328 return fen
171329
172330 def get_legal_moves (self , tensordict = None , uci = False ):
@@ -205,19 +363,40 @@ def _step(self, tensordict):
205363 # action
206364 action = tensordict .get ("action" )
207365 board = self .board
366+
208367 if not self .stateful :
209- fen = self ._get_fen (tensordict ).data
210- board .set_fen (fen )
368+ if self .include_fen :
369+ fen = self ._get_fen (tensordict ).data
370+ board .set_fen (fen )
371+ elif self .include_pgn :
372+ pgn = self ._get_pgn (tensordict ).data
373+ self ._pgn_to_board (pgn , board )
374+ else :
375+ raise RuntimeError (
376+ "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
377+ )
378+
211379 action = list (board .legal_moves )[action ]
380+ san = None
381+ if self .include_san :
382+ san = board .san (action )
212383 board .push (action )
384+
213385 self ._set_action_space ()
214386
215- # Collect data
216- fen = self .board .fen ()
217387 dest = tensordict .empty ()
218- hashing = hash (fen )
219- dest .set ("fen" , fen )
220- dest .set ("hashing" , hashing )
388+
389+ # Collect data
390+ if self .include_fen :
391+ fen = board .fen ()
392+ dest .set ("fen" , fen )
393+
394+ if self .include_pgn :
395+ pgn = self ._board_to_pgn (board )
396+ dest .set ("pgn" , pgn )
397+
398+ if san is not None :
399+ dest .set ("san" , san )
221400
222401 turn = torch .tensor (board .turn )
223402 if board .is_checkmate ():
@@ -226,12 +405,15 @@ def _step(self, tensordict):
226405 reward_val = 1 if winner == self .lib .WHITE else - 1
227406 else :
228407 reward_val = 0
408+
229409 reward = torch .tensor ([reward_val ], dtype = torch .int32 )
230410 done = self ._is_done (board )
231411 dest .set ("reward" , reward )
232412 dest .set ("turn" , turn )
233413 dest .set ("done" , [done ])
234414 dest .set ("terminated" , [done ])
415+ if self .pixels :
416+ dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
235417 return dest
236418
237419 def _set_seed (self , * args , ** kwargs ):
0 commit comments