Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d59fa26

Browse files
author
piotrmilos
committed
Allowing lambda's in the environment specification
1 parent 8e02ef2 commit d59fa26

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@
3333
import tensorflow as tf
3434

3535

36-
def define_train(hparams, environment_name):
36+
def define_train(hparams, environment_spec):
3737
"""Define the training setup."""
38-
env_lambda = lambda: gym.make(environment_name)
38+
if isinstance(environment_spec, str):
39+
env_lambda = lambda: gym.make(environment_spec)
40+
else:
41+
env_lambda = environment_spec
3942
policy_lambda = hparams.network
4043
env = env_lambda()
4144
action_space = env.action_space

0 commit comments

Comments
 (0)