diff --git a/tensorflow_datasets/rlds/rlds_base.py b/tensorflow_datasets/rlds/rlds_base.py index 9e2f7ce1b53..c880ef15717 100644 --- a/tensorflow_datasets/rlds/rlds_base.py +++ b/tensorflow_datasets/rlds/rlds_base.py @@ -75,6 +75,8 @@ def build_info( episode_metadata = ds_config.episode_metadata_info if episode_metadata is None: episode_metadata = {} + else: + episode_metadata = {"episode_metadata": episode_metadata} step_info = { 'is_terminal': tf.bool, 'is_first': tf.bool, @@ -163,9 +165,24 @@ def _generate_examples_from_log_path( key_prefix = os.path.basename(log_path) with envlogger.Reader(log_path) as reader: for episode_dict in envlogger_reader.generate_episodes(reader): + assert "steps" in episode_dict, "steps must be in episode_dict" + if "episode_metadata" not in episode_dict: + episode_metadata = { + key: value + for key, value in episode_dict.items() + if key != "steps" + } + episode_dict = { + "steps": episode_dict["steps"], + **( + {"episode_metadata": episode_metadata} + if episode_metadata + else {} + ), + } # The example ID should be unique. episode_id = counter - if 'episode_id' in episode_dict: - episode_id = episode_dict['episode_id'] + if 'episode_id' in episode_dict.get("episode_metadata", {}): + episode_id = episode_dict["episode_metadata"]['episode_id'] yield f'{key_prefix}/{episode_id}', episode_dict counter += 1