1616import json
1717import typing
1818import urllib .request
19+ import urllib .error
1920from urllib .parse import urlparse
2021import re
2122import base64
2223from dataclasses import dataclass
2324import argparse
2425
25-
2626def options ():
2727 p = argparse .ArgumentParser (description = __doc__ )
2828 p .add_argument ("--hash-only" , action = "store_true" )
2929 p .add_argument ("sources" , type = pathlib .Path , nargs = "+" )
3030 return p .parse_args ()
3131
3232
33+ TIMEOUT = 20
34+
35+ def warn (message : str ) -> None :
36+ print (f"WARNING: { message } " , file = sys .stderr )
37+
38+
3339@dataclass
3440class Endpoint :
3541 name : str
@@ -41,6 +47,10 @@ def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]):
4147 self .headers .update ((k .capitalize (), v ) for k , v in d )
4248
4349
50+ class NoEndpointsFound (Exception ):
51+ pass
52+
53+
4454opts = options ()
4555sources = [p .resolve () for p in opts .sources ]
4656source_dir = pathlib .Path (os .path .commonpath (src .parent for src in sources ))
@@ -105,18 +115,12 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
105115 "download" ,
106116 ]
107117 try :
108- res = subprocess .run (cmd , stdout = subprocess .PIPE , timeout = 15 )
118+ res = subprocess .run (cmd , stdout = subprocess .PIPE , timeout = TIMEOUT )
109119 except subprocess .TimeoutExpired :
110- print (
111- f"WARNING: ssh timed out when connecting to { server } , ignoring { endpoint .name } endpoint" ,
112- file = sys .stderr ,
113- )
120+ warn (f"ssh timed out when connecting to { server } , ignoring { endpoint .name } endpoint" )
114121 continue
115122 if res .returncode != 0 :
116- print (
117- f"WARNING: ssh failed when connecting to { server } , ignoring { endpoint .name } endpoint" ,
118- file = sys .stderr ,
119- )
123+ warn (f"ssh failed when connecting to { server } , ignoring { endpoint .name } endpoint" )
120124 continue
121125 ssh_resp = json .loads (res .stdout )
122126 endpoint .href = ssh_resp .get ("href" , endpoint )
@@ -139,10 +143,7 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
139143 input = f"protocol={ url .scheme } \n host={ url .netloc } \n path={ url .path [1 :]} \n " ,
140144 )
141145 if credentials is None :
142- print (
143- f"WARNING: no authorization method found, ignoring { data .name } endpoint" ,
144- file = sys .stderr ,
145- )
146+ warn (f"no authorization method found, ignoring { endpoint .name } endpoint" )
146147 continue
147148 credentials = dict (get_env (credentials ))
148149 auth = base64 .b64encode (
@@ -176,18 +177,18 @@ def get_locations(objects):
176177 data = json .dumps (data ).encode ("ascii" ),
177178 )
178179 try :
179- with urllib .request .urlopen (req ) as resp :
180+ with urllib .request .urlopen (req , timeout = TIMEOUT ) as resp :
180181 data = json .load (resp )
181- except urllib .request . HTTPError as e :
182- print (f"WARNING: encountered HTTPError { e } , ignoring endpoint { e .name } " )
182+ except urllib .error . URLError as e :
183+ warn (f"encountered { type ( e ). __name__ } { e } , ignoring endpoint { endpoint .name } " )
183184 continue
184185 assert len (data ["objects" ]) == len (
185186 indexes
186187 ), f"received { len (data )} objects, expected { len (indexes )} "
187188 for i , resp in zip (indexes , data ["objects" ]):
188189 ret [i ] = f'{ resp ["oid" ]} { resp ["actions" ]["download" ]["href" ]} '
189190 return ret
190- raise Exception ( f"no valid endpoint found" )
191+ raise NoEndpointsFound
191192
192193
193194def get_lfs_object (path ):
@@ -204,6 +205,10 @@ def get_lfs_object(path):
204205 return {"oid" : sha256 , "size" : size }
205206
206207
207- objects = [get_lfs_object (src ) for src in sources ]
208- for resp in get_locations (objects ):
209- print (resp )
208+ try :
209+ objects = [get_lfs_object (src ) for src in sources ]
210+ for resp in get_locations (objects ):
211+ print (resp )
212+ except NoEndpointsFound as e :
213+ print (f"ERROR: no valid endpoints found" , file = sys .stderr )
214+ sys .exit (1 )
0 commit comments