2424BLOCK_SIZE = 8192
2525
2626
27- def get_local_binary_path (name : str , url : str ) -> str :
27+ def get_local_binary_path (name : str , url : str , tmp_dir : Optional [ str ] = None ) -> str :
2828 """
2929 Returns the path to the executable previously downloaded with the name argument. If
3030 None is found, the executable at the url argument will be downloaded and stored
3131 under name for future uses.
3232 :param name: The name that will be given to the folder containing the extracted data
3333 :param url: The URL of the zip file
34+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
3435 """
3536 NUMBER_ATTEMPTS = 5
36- with FileLock (os .path .join (tempfile .gettempdir (), name + ".lock" )):
37- path = get_local_binary_path_if_exists (name , url )
37+ tmp_dir = tmp_dir or tempfile .gettempdir ()
38+ lock = FileLock (os .path .join (tmp_dir , name + ".lock" ))
39+ with lock :
40+ path = get_local_binary_path_if_exists (name , url , tmp_dir = tmp_dir )
3841 if path is None :
3942 logger .debug (
4043 f"Local environment { name } not found, downloading environment from { url } "
@@ -45,7 +48,7 @@ def get_local_binary_path(name: str, url: str) -> str:
4548 if path is not None :
4649 break
4750 try :
48- download_and_extract_zip (url , name )
51+ download_and_extract_zip (url , name , tmp_dir = tmp_dir )
4952 except Exception :
5053 if attempt + 1 < NUMBER_ATTEMPTS :
5154 logger .warning (
@@ -54,7 +57,7 @@ def get_local_binary_path(name: str, url: str) -> str:
5457 )
5558 else :
5659 raise
57- path = get_local_binary_path_if_exists (name , url )
60+ path = get_local_binary_path_if_exists (name , url , tmp_dir = tmp_dir )
5861
5962 if path is None :
6063 raise FileNotFoundError (
@@ -64,15 +67,16 @@ def get_local_binary_path(name: str, url: str) -> str:
6467 return path
6568
6669
67- def get_local_binary_path_if_exists (name : str , url : str ) -> Optional [str ]:
70+ def get_local_binary_path_if_exists (name : str , url : str , tmp_dir : str ) -> Optional [str ]:
6871 """
6972 Recursively searches for a Unity executable in the extracted files folders. This is
7073 platform dependent : It will only return a Unity executable compatible with the
7174 computer's OS. If no executable is found, None will be returned.
7275 :param name: The name/identifier of the executable
7376 :param url: The url the executable was downloaded from (for verification)
77+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
7478 """
75- _ , bin_dir = get_tmp_dir ( )
79+ _ , bin_dir = get_tmp_dirs ( tmp_dir )
7680 extension = None
7781
7882 if platform == "linux" or platform == "linux2" :
@@ -100,27 +104,27 @@ def get_local_binary_path_if_exists(name: str, url: str) -> Optional[str]:
100104 return None
101105
102106
103- def _get_tmp_dir_helper () :
104- TEMPDIR = "/tmp" if platform == "darwin" else tempfile .gettempdir ()
107+ def _get_tmp_dir_helper (tmp_dir : Optional [ str ] = None ) -> Tuple [ str , str ] :
108+ tmp_dir = tmp_dir or ( "/tmp" if platform == "darwin" else tempfile .gettempdir () )
105109 MLAGENTS = "ml-agents-binaries"
106110 TMP_FOLDER_NAME = "tmp"
107111 BINARY_FOLDER_NAME = "binaries"
108- mla_directory = os .path .join (TEMPDIR , MLAGENTS )
112+ mla_directory = os .path .join (tmp_dir , MLAGENTS )
109113 if not os .path .exists (mla_directory ):
110114 os .makedirs (mla_directory )
111115 os .chmod (mla_directory , 16877 )
112- zip_directory = os .path .join (TEMPDIR , MLAGENTS , TMP_FOLDER_NAME )
116+ zip_directory = os .path .join (tmp_dir , MLAGENTS , TMP_FOLDER_NAME )
113117 if not os .path .exists (zip_directory ):
114118 os .makedirs (zip_directory )
115119 os .chmod (zip_directory , 16877 )
116- bin_directory = os .path .join (TEMPDIR , MLAGENTS , BINARY_FOLDER_NAME )
120+ bin_directory = os .path .join (tmp_dir , MLAGENTS , BINARY_FOLDER_NAME )
117121 if not os .path .exists (bin_directory ):
118122 os .makedirs (bin_directory )
119123 os .chmod (bin_directory , 16877 )
120- return ( zip_directory , bin_directory )
124+ return zip_directory , bin_directory
121125
122126
123- def get_tmp_dir ( ) -> Tuple [str , str ]:
127+ def get_tmp_dirs ( tmp_dir : Optional [ str ] = None ) -> Tuple [str , str ]:
124128 """
125129 Returns the path to the folder containing the downloaded zip files and the extracted
126130 binaries. If these folders do not exist, they will be created.
@@ -130,21 +134,24 @@ def get_tmp_dir() -> Tuple[str, str]:
130134 # Should only be able to error out 3 times (once for each subdir).
131135 for _attempt in range (3 ):
132136 try :
133- return _get_tmp_dir_helper ()
137+ return _get_tmp_dir_helper (tmp_dir )
134138 except FileExistsError :
135139 continue
136- return _get_tmp_dir_helper ()
140+ return _get_tmp_dir_helper (tmp_dir )
137141
138142
139- def download_and_extract_zip (url : str , name : str ) -> None :
143+ def download_and_extract_zip (
144+ url : str , name : str , tmp_dir : Optional [str ] = None
145+ ) -> None :
140146 """
141147 Downloads a zip file under a URL, extracts its contents into a folder with the name
142148 argument and gives chmod 755 to all the files it contains. Files are downloaded and
143149 extracted into special folders in the temp folder of the machine.
144150 :param url: The URL of the zip file
145151 :param name: The name that will be given to the folder containing the extracted data
152+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
146153 """
147- zip_dir , bin_dir = get_tmp_dir ( )
154+ zip_dir , bin_dir = get_tmp_dirs ( tmp_dir )
148155 url_hash = "-" + hashlib .md5 (url .encode ()).hexdigest ()
149156 binary_path = os .path .join (bin_dir , name + url_hash )
150157 if os .path .exists (binary_path ):
@@ -206,7 +213,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
206213 """
207214 Converts a remote yaml file into a Python dictionary
208215 """
209- tmp_dir , _ = get_tmp_dir ()
216+ tmp_dir , _ = get_tmp_dirs ()
210217 try :
211218 request = urllib .request .urlopen (url , timeout = 30 )
212219 except urllib .error .HTTPError as e : # type: ignore
0 commit comments