Skip to content

Commit a2c2973

Browse files
committed
Update
[ghstack-poisoned]
1 parent 5d400b4 commit a2c2973

File tree

2 files changed

+2
-363
lines changed

2 files changed

+2
-363
lines changed

test/llm/test_envs.py

Lines changed: 2 additions & 362 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from __future__ import annotations
66

77
import argparse
8-
import contextlib
98
import importlib.util
109
import random
1110
import re
@@ -14,27 +13,15 @@
1413

1514
import pytest
1615
import torch
17-
from mocking_classes_llm import DummyStrDataLoader, DummyTensorDataLoader
18-
19-
from tensordict import (
20-
lazy_stack,
21-
NonTensorData,
22-
NonTensorStack,
23-
set_capture_non_tensor_stack,
24-
set_list_to_stack,
25-
TensorDict,
26-
)
16+
17+
from tensordict import lazy_stack, set_list_to_stack, TensorDict
2718

2819
from torchrl._utils import logger as torchrl_logger
2920
from torchrl.data.llm.history import History
30-
from torchrl.envs import StepCounter
3121
from torchrl.envs.llm import (
32-
as_padded_tensor,
3322
ChatEnv,
34-
DataLoadingPrimer,
3523
GSM8KEnv,
3624
KLRewardTransform,
37-
LLMEnv,
3825
make_gsm8k_env,
3926
RetrieveKL,
4027
)
@@ -82,353 +69,6 @@ def set_list_to_stack_for_test():
8269
return
8370

8471

85-
class TestLLMEnv:
86-
@pytest.fixture(scope="class", autouse=True)
87-
def set_capture(self):
88-
with set_capture_non_tensor_stack(False):
89-
yield None
90-
return
91-
92-
@pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
93-
@pytest.mark.parametrize(
94-
"from_text,stack_method",
95-
[
96-
[True, None],
97-
[False, "as_padded_tensor"],
98-
# TODO: a bit experimental, fails with check_env_specs
99-
# [False, "as_nested_tensor"],
100-
[False, None],
101-
],
102-
)
103-
@pytest.mark.parametrize("dl_batch_size", [1, 4])
104-
@pytest.mark.parametrize("env_batch_size", [None, 0, (), 4])
105-
@pytest.mark.parametrize("device", [None, "cpu"])
106-
def test_llm_env(
107-
self, from_text, stack_method, device, dl_batch_size, env_batch_size
108-
):
109-
if from_text:
110-
primer = DataLoadingPrimer(
111-
dataloader=DummyStrDataLoader(batch_size=dl_batch_size),
112-
batch_size=env_batch_size,
113-
)
114-
else:
115-
if stack_method is None:
116-
stack_method = as_padded_tensor
117-
primer = DataLoadingPrimer(
118-
dataloader=DummyTensorDataLoader(
119-
batch_size=dl_batch_size, padding=True
120-
),
121-
stack_method=stack_method,
122-
batch_size=env_batch_size,
123-
)
124-
with pytest.warns(UserWarning, match="eos_token_id"):
125-
env = LLMEnv(
126-
from_text=from_text,
127-
device=device,
128-
batch_size=primer.batch_size,
129-
)
130-
env = env.append_transform(primer)
131-
if env_batch_size is None:
132-
assert env.batch_size == torch.Size((dl_batch_size,))
133-
else:
134-
if not isinstance(env_batch_size, tuple):
135-
env_batch_size = (
136-
torch.Size(())
137-
if env_batch_size == 0
138-
else torch.Size((env_batch_size,))
139-
)
140-
assert env.batch_size == env_batch_size
141-
142-
env.check_env_specs(break_when_any_done="both")
143-
144-
@pytest.mark.skipif(not _has_transformers, reason="test requires transformers")
145-
@pytest.mark.parametrize("tokenizer", [True, False])
146-
@pytest.mark.parametrize(
147-
"from_text,stack_method",
148-
[
149-
[True, None],
150-
[False, "as_padded_tensor"],
151-
[False, None],
152-
],
153-
)
154-
@pytest.mark.parametrize("device", [None, "cpu"])
155-
@pytest.mark.parametrize("dl_batch_size", [1, 4])
156-
@pytest.mark.parametrize("env_batch_size", [None, 0, (), 4])
157-
def test_llm_from_dataloader(
158-
self,
159-
from_text,
160-
stack_method,
161-
device,
162-
dl_batch_size,
163-
env_batch_size,
164-
tokenizer,
165-
):
166-
from transformers import AutoTokenizer
167-
168-
if tokenizer and from_text:
169-
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
170-
else:
171-
tokenizer = None
172-
if from_text:
173-
kwargs = {
174-
"dataloader": DummyStrDataLoader(batch_size=dl_batch_size),
175-
}
176-
else:
177-
if stack_method is None:
178-
stack_method = as_padded_tensor
179-
kwargs = {
180-
"dataloader": DummyTensorDataLoader(
181-
padding=True, batch_size=dl_batch_size
182-
),
183-
"stack_method": stack_method,
184-
}
185-
kwargs.update(
186-
{
187-
"batch_size": env_batch_size,
188-
"from_text": from_text,
189-
"device": device,
190-
"has_attention": False,
191-
"tokenizer": tokenizer,
192-
}
193-
)
194-
with pytest.warns(UserWarning, match="eos_token_id"):
195-
env = LLMEnv.from_dataloader(**kwargs)
196-
if env_batch_size is None:
197-
assert env.batch_size == torch.Size((dl_batch_size,))
198-
else:
199-
if not isinstance(env_batch_size, tuple):
200-
env_batch_size = (
201-
torch.Size(())
202-
if env_batch_size == 0
203-
else torch.Size((env_batch_size,))
204-
)
205-
assert env.batch_size == env_batch_size
206-
env.check_env_specs(break_when_any_done="both")
207-
208-
def policy(td):
209-
if from_text and tokenizer is None:
210-
if not td.shape:
211-
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData(
212-
"<nothing>", device=device
213-
)
214-
else:
215-
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
216-
*[
217-
NonTensorData("<nothing>", device=device)
218-
for _ in range(td.shape[0])
219-
]
220-
)
221-
else:
222-
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
223-
td.shape + (1,), dtype=torch.int64
224-
)
225-
return td
226-
227-
r = env.rollout(10, policy)
228-
if env.batch_size == ():
229-
assert r.ndim == 1
230-
r = r.unsqueeze(0)
231-
else:
232-
assert r.ndim == 2
233-
if from_text and tokenizer is None:
234-
assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
235-
assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
236-
assert (
237-
r[0, 0][LLMEnv._DEFAULT_STR_KEY]
238-
== r[0, 1][LLMEnv._DEFAULT_STR_KEY][
239-
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
240-
]
241-
), (
242-
r[0, 0][LLMEnv._DEFAULT_STR_KEY],
243-
r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY],
244-
r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY],
245-
r[0, 1][LLMEnv._DEFAULT_STR_KEY],
246-
)
247-
assert (
248-
r[0, 1][LLMEnv._DEFAULT_STR_KEY]
249-
== r[0, 2][LLMEnv._DEFAULT_STR_KEY][
250-
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
251-
]
252-
)
253-
assert (
254-
r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
255-
== r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
256-
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
257-
]
258-
)
259-
assert (
260-
r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
261-
== r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
262-
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
263-
]
264-
)
265-
elif tokenizer is None:
266-
assert (
267-
r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
268-
== r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
269-
).all()
270-
assert (
271-
r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
272-
== r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
273-
).all()
274-
assert (
275-
r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
276-
== r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
277-
).all()
278-
assert (
279-
r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
280-
== r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
281-
).all()
282-
283-
@pytest.mark.parametrize(
284-
"from_text,stack_method",
285-
[
286-
[True, None],
287-
[False, "as_padded_tensor"],
288-
# TODO: a bit experimental, fails with check_env_specs
289-
# [False, "as_nested_tensor"],
290-
[False, None],
291-
],
292-
)
293-
@pytest.mark.parametrize("device", [None, "cpu"])
294-
@pytest.mark.parametrize("dl_batch_size", [1, 4])
295-
@pytest.mark.parametrize("env_batch_size", [None, 0, (), 4])
296-
@pytest.mark.parametrize("repeats", [3])
297-
def test_llm_from_dataloader_repeats(
298-
self, from_text, stack_method, device, env_batch_size, dl_batch_size, repeats
299-
):
300-
if from_text:
301-
kwargs = {
302-
"dataloader": DummyStrDataLoader(batch_size=dl_batch_size),
303-
"repeats": repeats,
304-
}
305-
else:
306-
if stack_method is None:
307-
stack_method = as_padded_tensor
308-
kwargs = {
309-
"dataloader": DummyTensorDataLoader(
310-
padding=True, batch_size=dl_batch_size
311-
),
312-
"stack_method": stack_method,
313-
"repeats": repeats,
314-
}
315-
kwargs.update(
316-
{
317-
"batch_size": env_batch_size,
318-
"from_text": from_text,
319-
"device": device,
320-
"has_attention": False,
321-
}
322-
)
323-
with pytest.warns(UserWarning, match="eos_token_id"):
324-
env = LLMEnv.from_dataloader(**kwargs)
325-
assert env.transform.repeats == repeats
326-
327-
max_steps = 3
328-
env.append_transform(StepCounter(max_steps=max_steps))
329-
330-
def policy(td):
331-
if from_text:
332-
if not td.shape:
333-
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
334-
else:
335-
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
336-
*["<nothing>" for _ in range(td.shape[0])]
337-
)
338-
else:
339-
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
340-
td.shape + (1,), dtype=torch.int64
341-
)
342-
return td
343-
344-
r = env.rollout(100, policy, break_when_any_done=False)
345-
# check that r at reset is always the same
346-
r_reset = r[..., ::max_steps]
347-
if from_text:
348-
all_strings = r_reset.view(-1)[LLMEnv._DEFAULT_STR_KEY]
349-
assert sum(s == all_strings[0] for s in all_strings) == repeats
350-
assert sum(s == all_strings[repeats] for s in all_strings) == repeats
351-
assert sum(s == all_strings[repeats * 2] for s in all_strings) == repeats
352-
else:
353-
all_tokens = r_reset.view(-1)[LLMEnv._DEFAULT_TOKEN_KEY]
354-
assert sum((s == all_tokens[0]).all() for s in all_tokens) == repeats
355-
assert sum((s == all_tokens[repeats]).all() for s in all_tokens) == repeats
356-
assert (
357-
sum((s == all_tokens[repeats * 2]).all() for s in all_tokens) == repeats
358-
)
359-
360-
@pytest.mark.parametrize(
361-
"from_text,stack_method",
362-
[
363-
[True, None],
364-
[False, "as_padded_tensor"],
365-
],
366-
)
367-
@pytest.mark.parametrize("device", [None])
368-
@pytest.mark.parametrize("dl_batch_size", [1, 4])
369-
@pytest.mark.parametrize("env_batch_size", [None, 0, (), 4])
370-
@pytest.mark.parametrize("repeats", [3])
371-
@pytest.mark.parametrize(
372-
"assign_reward,assign_done", [[True, False], [True, True], [False, True]]
373-
)
374-
def test_done_and_reward(
375-
self,
376-
from_text,
377-
stack_method,
378-
device,
379-
env_batch_size,
380-
dl_batch_size,
381-
repeats,
382-
assign_reward,
383-
assign_done,
384-
):
385-
with pytest.raises(
386-
ValueError, match="from_text"
387-
) if from_text else contextlib.nullcontext():
388-
if from_text:
389-
kwargs = {
390-
"dataloader": DummyStrDataLoader(batch_size=dl_batch_size),
391-
"repeats": repeats,
392-
"assign_reward": assign_reward,
393-
"assign_done": assign_done,
394-
}
395-
else:
396-
if stack_method is None:
397-
stack_method = as_padded_tensor
398-
kwargs = {
399-
"dataloader": DummyTensorDataLoader(
400-
padding=True, batch_size=dl_batch_size
401-
),
402-
"stack_method": stack_method,
403-
"repeats": repeats,
404-
"assign_reward": assign_reward,
405-
"assign_done": assign_done,
406-
}
407-
kwargs.update(
408-
{
409-
"batch_size": env_batch_size,
410-
"from_text": from_text,
411-
"device": device,
412-
"has_attention": False,
413-
}
414-
)
415-
with pytest.warns(UserWarning, match="eos_token_id"):
416-
env = LLMEnv.from_dataloader(**kwargs)
417-
# We want to make sure that transforms that rely on the done state work appropriately
418-
env.append_transform(StepCounter(max_steps=10))
419-
420-
def policy(td):
421-
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
422-
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
423-
)
424-
return td
425-
426-
r = env.rollout(100, policy, break_when_any_done=False)
427-
if assign_done:
428-
assert "terminated" in r
429-
assert "done" in r
430-
431-
43272
class TestChatEnv:
43373
@pytest.fixture
43474
def tokenizer(self):

torchrl/testing/ray_helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def update_weights(self, modify_weights: bool = False):
159159
Returns:
160160
str: "updated" status message
161161
"""
162-
163162
# Optionally modify weights for testing
164163
if modify_weights:
165164
with torch.no_grad():

0 commit comments

Comments (0)