import tensordict.nn
import torch
import tqdm
from tensordict.nn import (
    ProbabilisticTensorDictModule as TDProb,
    ProbabilisticTensorDictSequential as TDProbSeq,
    TensorDictModule as TDMod,
    TensorDictSequential as TDSeq,
)
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

from torchrl.envs import ChessEnv, Tokenizer
from torchrl.modules import MLP
from torchrl.modules.distributions import MaskedCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
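# Keep per-key log-probabilities as separate entries in the tensordict rather
# than aggregating them into a single log-prob tensor.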
tensordict.nn.set_composite_lp_aggregate(False)

# ... (elided: environment and tokenizer setup defining `transform` and `n`)
# Embedding for the legal-move indices
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

# Embedding for the tokenized FEN board description
embedding_fen = nn.Embedding(
    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
)

backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)
4551
critic_head = nn.Linear(512, 1)
# Zero the bias so that initial value estimates start at 0
critic_head.bias.data.fill_(0)
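# Action head: a categorical distribution over the logits, masked to the legal
# moves; the sampled action's log-prob is kept for PPO.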
prob = TDProb(
    in_keys=["logits", "mask"],
    out_keys=["action"],
    distribution_class=MaskedCategorical,
    return_log_prob=True,
)
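# Build a boolean mask over the n + 1 move slots marking which moves are
# legal, then drop the trailing padding slot.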
def make_mask(idx):
    mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]
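# Policy: mask the legal moves, embed the moves and the tokenized FEN, flatten
# and concatenate the embeddings, then map them to logits for the masked
# categorical action head.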
actor = TDProbSeq(
    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
    TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
    TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
    TDMod(
        lambda *args: torch.cat(
            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
        ),
        in_keys=["embedded_legal_moves", "embedded_fen"],
        out_keys=["features"],
    ),
    TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
    TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
    prob,
)

# ... (elided: critic assembly and ClipPPOLoss construction defining `critic`
# and `loss`)
optim = Adam(loss.parameters())
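# GAE advantages computed with a critic that shares the actor's trunk (all but
# its last two modules); shifted=True reuses a single value-network call for
# both the current and the next state values.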
gae = GAE(
    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
)
# Create a data collector
collector = SyncDataCollector(
    # ... (elided: environment and policy arguments)
    total_frames=1_000_000,
)
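# One replay buffer per player: each stores half of the frames collected per
# batch and samples minibatches without replacement.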
replay_buffer0 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)
replay_buffer1 = ReplayBuffer(
    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
    batch_size=batch_size,
    sampler=SamplerWithoutReplacement(),
)
for data in tqdm.tqdm(collector):
    data = data.filter_non_tensor_data()
    # Even-indexed frames are player 0's moves, odd-indexed frames player 1's
    print("data", data[0::2])
    for i in range(num_epochs):
        replay_buffer0.empty()
        replay_buffer1.empty()
        # player 0
        data0 = gae(data[0::2])
        # player 1
        data1 = gae(data[1::2])
        if i == 0:
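            # Rewards are paid at the terminal step, so each player's summed
            # rewards divided by the number of finished games estimates its
            # win rate (clamped to avoid division by zero).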
            print(
                "win rate for 0",
                data0["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )
            print(
                "win rate for 1",
                data1["next", "reward"].sum()
                / data["next", "done"].sum().clamp_min(1e-6),
            )
        replay_buffer0.extend(data0)
        replay_buffer1.extend(data1)
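        # Draw minibatches from both players' buffers in lockstep and average
        # the two PPO losses.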
        n_iter = collector.frames_per_batch // (2 * batch_size)
        for (d0, d1) in tqdm.tqdm(
            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
        ):
            loss_vals = (loss(d0) + loss(d1)) / 2
            loss_vals.sum(reduce=True).backward()
            # Clip gradients; gn holds the pre-clipping gradient norm
            gn = clip_grad_norm_(loss.parameters(), 100.0)