5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import argparse |
8 | | -import contextlib |
9 | 8 | import importlib.util |
10 | 9 | import random |
11 | 10 | import re |
|
14 | 13 |
|
15 | 14 | import pytest |
16 | 15 | import torch |
17 | | -from mocking_classes_llm import DummyStrDataLoader, DummyTensorDataLoader |
18 | | - |
19 | | -from tensordict import ( |
20 | | - lazy_stack, |
21 | | - NonTensorData, |
22 | | - NonTensorStack, |
23 | | - set_capture_non_tensor_stack, |
24 | | - set_list_to_stack, |
25 | | - TensorDict, |
26 | | -) |
| 16 | + |
| 17 | +from tensordict import lazy_stack, set_list_to_stack, TensorDict |
27 | 18 |
|
28 | 19 | from torchrl._utils import logger as torchrl_logger |
29 | 20 | from torchrl.data.llm.history import History |
30 | | -from torchrl.envs import StepCounter |
31 | 21 | from torchrl.envs.llm import ( |
32 | | - as_padded_tensor, |
33 | 22 | ChatEnv, |
34 | | - DataLoadingPrimer, |
35 | 23 | GSM8KEnv, |
36 | 24 | KLRewardTransform, |
37 | | - LLMEnv, |
38 | 25 | make_gsm8k_env, |
39 | 26 | RetrieveKL, |
40 | 27 | ) |
@@ -82,353 +69,6 @@ def set_list_to_stack_for_test(): |
82 | 69 | return |
83 | 70 |
|
84 | 71 |
|
85 | | -class TestLLMEnv: |
86 | | - @pytest.fixture(scope="class", autouse=True) |
87 | | - def set_capture(self): |
88 | | - with set_capture_non_tensor_stack(False): |
89 | | - yield None |
90 | | - return |
91 | | - |
92 | | - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") |
93 | | - @pytest.mark.parametrize( |
94 | | - "from_text,stack_method", |
95 | | - [ |
96 | | - [True, None], |
97 | | - [False, "as_padded_tensor"], |
98 | | - # TODO: a bit experimental, fails with check_env_specs |
99 | | - # [False, "as_nested_tensor"], |
100 | | - [False, None], |
101 | | - ], |
102 | | - ) |
103 | | - @pytest.mark.parametrize("dl_batch_size", [1, 4]) |
104 | | - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) |
105 | | - @pytest.mark.parametrize("device", [None, "cpu"]) |
106 | | - def test_llm_env( |
107 | | - self, from_text, stack_method, device, dl_batch_size, env_batch_size |
108 | | - ): |
109 | | - if from_text: |
110 | | - primer = DataLoadingPrimer( |
111 | | - dataloader=DummyStrDataLoader(batch_size=dl_batch_size), |
112 | | - batch_size=env_batch_size, |
113 | | - ) |
114 | | - else: |
115 | | - if stack_method is None: |
116 | | - stack_method = as_padded_tensor |
117 | | - primer = DataLoadingPrimer( |
118 | | - dataloader=DummyTensorDataLoader( |
119 | | - batch_size=dl_batch_size, padding=True |
120 | | - ), |
121 | | - stack_method=stack_method, |
122 | | - batch_size=env_batch_size, |
123 | | - ) |
124 | | - with pytest.warns(UserWarning, match="eos_token_id"): |
125 | | - env = LLMEnv( |
126 | | - from_text=from_text, |
127 | | - device=device, |
128 | | - batch_size=primer.batch_size, |
129 | | - ) |
130 | | - env = env.append_transform(primer) |
131 | | - if env_batch_size is None: |
132 | | - assert env.batch_size == torch.Size((dl_batch_size,)) |
133 | | - else: |
134 | | - if not isinstance(env_batch_size, tuple): |
135 | | - env_batch_size = ( |
136 | | - torch.Size(()) |
137 | | - if env_batch_size == 0 |
138 | | - else torch.Size((env_batch_size,)) |
139 | | - ) |
140 | | - assert env.batch_size == env_batch_size |
141 | | - |
142 | | - env.check_env_specs(break_when_any_done="both") |
143 | | - |
144 | | - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") |
145 | | - @pytest.mark.parametrize("tokenizer", [True, False]) |
146 | | - @pytest.mark.parametrize( |
147 | | - "from_text,stack_method", |
148 | | - [ |
149 | | - [True, None], |
150 | | - [False, "as_padded_tensor"], |
151 | | - [False, None], |
152 | | - ], |
153 | | - ) |
154 | | - @pytest.mark.parametrize("device", [None, "cpu"]) |
155 | | - @pytest.mark.parametrize("dl_batch_size", [1, 4]) |
156 | | - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) |
157 | | - def test_llm_from_dataloader( |
158 | | - self, |
159 | | - from_text, |
160 | | - stack_method, |
161 | | - device, |
162 | | - dl_batch_size, |
163 | | - env_batch_size, |
164 | | - tokenizer, |
165 | | - ): |
166 | | - from transformers import AutoTokenizer |
167 | | - |
168 | | - if tokenizer and from_text: |
169 | | - tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") |
170 | | - else: |
171 | | - tokenizer = None |
172 | | - if from_text: |
173 | | - kwargs = { |
174 | | - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), |
175 | | - } |
176 | | - else: |
177 | | - if stack_method is None: |
178 | | - stack_method = as_padded_tensor |
179 | | - kwargs = { |
180 | | - "dataloader": DummyTensorDataLoader( |
181 | | - padding=True, batch_size=dl_batch_size |
182 | | - ), |
183 | | - "stack_method": stack_method, |
184 | | - } |
185 | | - kwargs.update( |
186 | | - { |
187 | | - "batch_size": env_batch_size, |
188 | | - "from_text": from_text, |
189 | | - "device": device, |
190 | | - "has_attention": False, |
191 | | - "tokenizer": tokenizer, |
192 | | - } |
193 | | - ) |
194 | | - with pytest.warns(UserWarning, match="eos_token_id"): |
195 | | - env = LLMEnv.from_dataloader(**kwargs) |
196 | | - if env_batch_size is None: |
197 | | - assert env.batch_size == torch.Size((dl_batch_size,)) |
198 | | - else: |
199 | | - if not isinstance(env_batch_size, tuple): |
200 | | - env_batch_size = ( |
201 | | - torch.Size(()) |
202 | | - if env_batch_size == 0 |
203 | | - else torch.Size((env_batch_size,)) |
204 | | - ) |
205 | | - assert env.batch_size == env_batch_size |
206 | | - env.check_env_specs(break_when_any_done="both") |
207 | | - |
208 | | - def policy(td): |
209 | | - if from_text and tokenizer is None: |
210 | | - if not td.shape: |
211 | | - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData( |
212 | | - "<nothing>", device=device |
213 | | - ) |
214 | | - else: |
215 | | - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( |
216 | | - *[ |
217 | | - NonTensorData("<nothing>", device=device) |
218 | | - for _ in range(td.shape[0]) |
219 | | - ] |
220 | | - ) |
221 | | - else: |
222 | | - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( |
223 | | - td.shape + (1,), dtype=torch.int64 |
224 | | - ) |
225 | | - return td |
226 | | - |
227 | | - r = env.rollout(10, policy) |
228 | | - if env.batch_size == (): |
229 | | - assert r.ndim == 1 |
230 | | - r = r.unsqueeze(0) |
231 | | - else: |
232 | | - assert r.ndim == 2 |
233 | | - if from_text and tokenizer is None: |
234 | | - assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) |
235 | | - assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) |
236 | | - assert ( |
237 | | - r[0, 0][LLMEnv._DEFAULT_STR_KEY] |
238 | | - == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ |
239 | | - : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) |
240 | | - ] |
241 | | - ), ( |
242 | | - r[0, 0][LLMEnv._DEFAULT_STR_KEY], |
243 | | - r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY], |
244 | | - r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY], |
245 | | - r[0, 1][LLMEnv._DEFAULT_STR_KEY], |
246 | | - ) |
247 | | - assert ( |
248 | | - r[0, 1][LLMEnv._DEFAULT_STR_KEY] |
249 | | - == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ |
250 | | - : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) |
251 | | - ] |
252 | | - ) |
253 | | - assert ( |
254 | | - r[-1, 0][LLMEnv._DEFAULT_STR_KEY] |
255 | | - == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ |
256 | | - : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) |
257 | | - ] |
258 | | - ) |
259 | | - assert ( |
260 | | - r[-1, 1][LLMEnv._DEFAULT_STR_KEY] |
261 | | - == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ |
262 | | - : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) |
263 | | - ] |
264 | | - ) |
265 | | - elif tokenizer is None: |
266 | | - assert ( |
267 | | - r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] |
268 | | - == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] |
269 | | - ).all() |
270 | | - assert ( |
271 | | - r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] |
272 | | - == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] |
273 | | - ).all() |
274 | | - assert ( |
275 | | - r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY] |
276 | | - == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] |
277 | | - ).all() |
278 | | - assert ( |
279 | | - r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY] |
280 | | - == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] |
281 | | - ).all() |
282 | | - |
283 | | - @pytest.mark.parametrize( |
284 | | - "from_text,stack_method", |
285 | | - [ |
286 | | - [True, None], |
287 | | - [False, "as_padded_tensor"], |
288 | | - # TODO: a bit experimental, fails with check_env_specs |
289 | | - # [False, "as_nested_tensor"], |
290 | | - [False, None], |
291 | | - ], |
292 | | - ) |
293 | | - @pytest.mark.parametrize("device", [None, "cpu"]) |
294 | | - @pytest.mark.parametrize("dl_batch_size", [1, 4]) |
295 | | - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) |
296 | | - @pytest.mark.parametrize("repeats", [3]) |
297 | | - def test_llm_from_dataloader_repeats( |
298 | | - self, from_text, stack_method, device, env_batch_size, dl_batch_size, repeats |
299 | | - ): |
300 | | - if from_text: |
301 | | - kwargs = { |
302 | | - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), |
303 | | - "repeats": repeats, |
304 | | - } |
305 | | - else: |
306 | | - if stack_method is None: |
307 | | - stack_method = as_padded_tensor |
308 | | - kwargs = { |
309 | | - "dataloader": DummyTensorDataLoader( |
310 | | - padding=True, batch_size=dl_batch_size |
311 | | - ), |
312 | | - "stack_method": stack_method, |
313 | | - "repeats": repeats, |
314 | | - } |
315 | | - kwargs.update( |
316 | | - { |
317 | | - "batch_size": env_batch_size, |
318 | | - "from_text": from_text, |
319 | | - "device": device, |
320 | | - "has_attention": False, |
321 | | - } |
322 | | - ) |
323 | | - with pytest.warns(UserWarning, match="eos_token_id"): |
324 | | - env = LLMEnv.from_dataloader(**kwargs) |
325 | | - assert env.transform.repeats == repeats |
326 | | - |
327 | | - max_steps = 3 |
328 | | - env.append_transform(StepCounter(max_steps=max_steps)) |
329 | | - |
330 | | - def policy(td): |
331 | | - if from_text: |
332 | | - if not td.shape: |
333 | | - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>" |
334 | | - else: |
335 | | - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( |
336 | | - *["<nothing>" for _ in range(td.shape[0])] |
337 | | - ) |
338 | | - else: |
339 | | - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( |
340 | | - td.shape + (1,), dtype=torch.int64 |
341 | | - ) |
342 | | - return td |
343 | | - |
344 | | - r = env.rollout(100, policy, break_when_any_done=False) |
345 | | - # check that r at reset is always the same |
346 | | - r_reset = r[..., ::max_steps] |
347 | | - if from_text: |
348 | | - all_strings = r_reset.view(-1)[LLMEnv._DEFAULT_STR_KEY] |
349 | | - assert sum(s == all_strings[0] for s in all_strings) == repeats |
350 | | - assert sum(s == all_strings[repeats] for s in all_strings) == repeats |
351 | | - assert sum(s == all_strings[repeats * 2] for s in all_strings) == repeats |
352 | | - else: |
353 | | - all_tokens = r_reset.view(-1)[LLMEnv._DEFAULT_TOKEN_KEY] |
354 | | - assert sum((s == all_tokens[0]).all() for s in all_tokens) == repeats |
355 | | - assert sum((s == all_tokens[repeats]).all() for s in all_tokens) == repeats |
356 | | - assert ( |
357 | | - sum((s == all_tokens[repeats * 2]).all() for s in all_tokens) == repeats |
358 | | - ) |
359 | | - |
360 | | - @pytest.mark.parametrize( |
361 | | - "from_text,stack_method", |
362 | | - [ |
363 | | - [True, None], |
364 | | - [False, "as_padded_tensor"], |
365 | | - ], |
366 | | - ) |
367 | | - @pytest.mark.parametrize("device", [None]) |
368 | | - @pytest.mark.parametrize("dl_batch_size", [1, 4]) |
369 | | - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) |
370 | | - @pytest.mark.parametrize("repeats", [3]) |
371 | | - @pytest.mark.parametrize( |
372 | | - "assign_reward,assign_done", [[True, False], [True, True], [False, True]] |
373 | | - ) |
374 | | - def test_done_and_reward( |
375 | | - self, |
376 | | - from_text, |
377 | | - stack_method, |
378 | | - device, |
379 | | - env_batch_size, |
380 | | - dl_batch_size, |
381 | | - repeats, |
382 | | - assign_reward, |
383 | | - assign_done, |
384 | | - ): |
385 | | - with pytest.raises( |
386 | | - ValueError, match="from_text" |
387 | | - ) if from_text else contextlib.nullcontext(): |
388 | | - if from_text: |
389 | | - kwargs = { |
390 | | - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), |
391 | | - "repeats": repeats, |
392 | | - "assign_reward": assign_reward, |
393 | | - "assign_done": assign_done, |
394 | | - } |
395 | | - else: |
396 | | - if stack_method is None: |
397 | | - stack_method = as_padded_tensor |
398 | | - kwargs = { |
399 | | - "dataloader": DummyTensorDataLoader( |
400 | | - padding=True, batch_size=dl_batch_size |
401 | | - ), |
402 | | - "stack_method": stack_method, |
403 | | - "repeats": repeats, |
404 | | - "assign_reward": assign_reward, |
405 | | - "assign_done": assign_done, |
406 | | - } |
407 | | - kwargs.update( |
408 | | - { |
409 | | - "batch_size": env_batch_size, |
410 | | - "from_text": from_text, |
411 | | - "device": device, |
412 | | - "has_attention": False, |
413 | | - } |
414 | | - ) |
415 | | - with pytest.warns(UserWarning, match="eos_token_id"): |
416 | | - env = LLMEnv.from_dataloader(**kwargs) |
417 | | - # We want to make sure that transforms that rely on the done state work appropriately |
418 | | - env.append_transform(StepCounter(max_steps=10)) |
419 | | - |
420 | | - def policy(td): |
421 | | - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( |
422 | | - td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 |
423 | | - ) |
424 | | - return td |
425 | | - |
426 | | - r = env.rollout(100, policy, break_when_any_done=False) |
427 | | - if assign_done: |
428 | | - assert "terminated" in r |
429 | | - assert "done" in r |
430 | | - |
431 | | - |
432 | 72 | class TestChatEnv: |
433 | 73 | @pytest.fixture |
434 | 74 | def tokenizer(self): |
|