Skip to content

Commit e954e97

Browse files
[Tests] Fix vmas seeding test (#3210)
1 parent 071d079 commit e954e97

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

test/test_libs.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import collections
8+
import copy
89
import functools
910
import gc
1011
import importlib.util
@@ -2811,14 +2812,27 @@ def test_vmas_seeding(self, scenario_name):
28112812
final_seed = []
28122813
tdreset = []
28132814
tdrollout = []
2814-
for _ in range(2):
2815-
env = VmasEnv(
2815+
rollout_length = 10
2816+
2817+
def create_env():
2818+
return VmasEnv(
28162819
scenario=scenario_name,
28172820
num_envs=4,
28182821
)
2822+
2823+
env = create_env()
2824+
td_actions = [env.action_spec.rand() for _ in range(rollout_length)]
2825+
2826+
for _ in range(2):
2827+
env = create_env()
2828+
td_actions_buffer = copy.deepcopy(td_actions)
2829+
2830+
def policy(td, actions=td_actions_buffer):
2831+
return actions.pop(0)
2832+
28192833
final_seed.append(env.set_seed(0))
28202834
tdreset.append(env.reset())
2821-
tdrollout.append(env.rollout(max_steps=10))
2835+
tdrollout.append(env.rollout(max_steps=rollout_length, policy=policy))
28222836
env.close()
28232837
del env
28242838
assert final_seed[0] == final_seed[1]

0 commit comments

Comments
 (0)