2222from typing import Callable , Mapping , Sequence , TypeVar , Union
2323
2424from absl import logging
25+ from etils import epath
2526import pandas as pd
2627from tensorflow_datasets .datasets .smart_buildings import constants
2728from tensorflow_datasets .datasets .smart_buildings import reader_lib
@@ -45,8 +46,8 @@ class ProtoReader(reader_lib.BaseReader):
4546 input_dir: directory path where the files are located
4647 """
4748
48- def __init__ (self , input_dir ):
49- self ._input_dir = input_dir
49+ def __init__ (self , input_dir : epath . PathLike ):
50+ self ._input_dir = epath . Path ( input_dir )
5051 logging .info ('Reader lib input directory %s' , self ._input_dir )
5152
5253 def read_observation_responses (
@@ -97,7 +98,7 @@ def read_reward_responses( # pytype: disable=signature-mismatch # overriding-r
9798
9899 def read_zone_infos (self ) -> Sequence [smart_control_building_pb2 .ZoneInfo ]:
99100 """Reads the zone infos for the Building from .pbtxt."""
100- filename = os . path . join ( self ._input_dir , constants .ZONE_INFO_PREFIX )
101+ filename = self ._input_dir / constants .ZONE_INFO_PREFIX
101102 return self ._read_streamed_protos (
102103 filename , smart_control_building_pb2 .ZoneInfo .FromString
103104 )
@@ -107,7 +108,7 @@ def read_device_infos(
107108 ) -> Sequence [smart_control_building_pb2 .DeviceInfo ]:
108109 """Reads the device infos for the Building."""
109110
110- filename = os . path . join ( self ._input_dir , constants .DEVICE_INFO_PREFIX )
111+ filename = self ._input_dir / constants .DEVICE_INFO_PREFIX
111112 return self ._read_streamed_protos (
112113 filename , smart_control_building_pb2 .DeviceInfo .FromString
113114 )
@@ -141,28 +142,26 @@ def _read_messages(
141142 messages .extend (file_messages )
142143 return messages
143144
144- def _read_shards (self , input_dir : str , file_prefix : str ) -> Sequence [str ]:
145+ def _read_shards (
146+ self , input_dir : epath .Path , file_prefix : str
147+ ) -> Sequence [epath .Path ]:
145148 """Returns full paths in input_dir of files starting with file_prefix."""
146-
147- shards = [
148- os .path .join (input_dir , f )
149- for f in os .listdir (input_dir )
150- if f .startswith (file_prefix )
151- ]
152- return shards
149+ return list (epath .Path (input_dir ).glob (f'{ file_prefix } *' ))
153150
154151 def _select_shards (
155152 self ,
156153 start_time : pd .Timestamp ,
157154 end_time : pd .Timestamp ,
158- shards : Sequence [str ],
159- ) -> Sequence [str ]:
155+ shards : Sequence [epath . Path ],
156+ ) -> Sequence [epath . Path ]:
160157 """Returns the shards that fall inside the start and end times."""
161158
162- def _read_timestamp (filepath : str ) -> pd .Timestamp :
159+ def _read_timestamp (filepath : epath . Path ) -> pd .Timestamp :
163160 """Reads the timestamp from the filepath."""
164161 assert filepath
165- ts = pd .Timestamp (re .findall (r'\d{4}\.\d{2}\.\d{2}\.\d{2}' , filepath )[- 1 ])
162+ ts = pd .Timestamp (
163+ re .findall (r'\d{4}\.\d{2}\.\d{2}\.\d{2}' , os .fspath (filepath ))[- 1 ]
164+ )
166165 return ts
167166
168167 def _between (
@@ -179,13 +178,13 @@ def _between(
179178
180179 def _read_streamed_protos (
181180 self ,
182- full_path : str ,
181+ full_path : epath . Path ,
183182 from_string_func : Callable [[Union [bytearray , bytes , memoryview ]], T ],
184183 ) -> Sequence [T ]:
185184 """Reads a proto which has byte size preceding the message."""
186185
187186 messages = []
188- with open (full_path , 'rb' ) as f :
187+ with full_path . open ('rb' ) as f :
189188 while True :
190189 # Read size as a varint
191190 size_bytes = f .read (4 )
@@ -260,7 +259,7 @@ def get_episode_data(working_dir: str) -> pd.DataFrame:
260259 Returns:
261260 A dataframe with episode label, timestamps, number of updates.
262261 """
263- episode_dirs = os . listdir (working_dir )
262+ episode_dirs = list ( epath . Path (working_dir ). iterdir () )
264263 date_extractor = operator .itemgetter (slice (- 13 , None ))
265264
266265 execution_times = pd .to_datetime (
0 commit comments