1212from dotenv import load_dotenv
1313from tqdm import tqdm
1414
15+ from roboflow .adapters import rfapi
1516from roboflow .config import (
1617 API_URL ,
1718 APP_URL ,
@@ -92,11 +93,11 @@ def __init__(
9293
9394 version_without_workspace = os .path .basename (str (version ))
9495
95- response = requests . get ( f" { API_URL } / { workspace } / { project } / { self . version } ?api_key= { self . __api_key } " )
96- if response . ok :
97- version_info = response . json ()[ "version" ]
96+ try :
97+ version_response = rfapi . get_version ( self . __api_key , workspace , project , self . version )
98+ version_info = version_response . get ( "version" , {})
9899 has_model = bool (version_info .get ("train" , {}).get ("model" ))
99- else :
100+ except rfapi . RoboflowError :
100101 has_model = False
101102
102103 if not has_model :
@@ -152,16 +153,17 @@ def __init__(
152153
153154 def __check_if_generating (self ):
154155 # check Roboflow API to see if this version is still generating
155-
156- url = f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } ?nocache=true"
157- response = requests .get (url , params = {"api_key" : self .__api_key })
158- response .raise_for_status ()
159- if response .json ()["version" ]["progress" ] is None :
160- progress = 0.0
161- else :
162- progress = float (response .json ()["version" ]["progress" ])
163-
164- return response .json ()["version" ]["generating" ], progress
156+ versiondict = rfapi .get_version (
157+ api_key = self .__api_key ,
158+ workspace_url = self .workspace ,
159+ project_url = self .project ,
160+ version = self .version ,
161+ nocache = True ,
162+ )
163+ version_obj = versiondict .get ("version" , {})
164+ progress = 0.0 if version_obj .get ("progress" ) is None else float (version_obj .get ("progress" ))
165+ generating = bool (version_obj .get ("generating" ) or version_obj .get ("images" , 0 ) == 0 )
166+ return generating , progress
165167
166168 def __wait_if_generating (self , recurse = False ):
167169 # checks if a given version is still in the progress of generating
@@ -219,15 +221,22 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
219221 if self .__api_key == "coco-128-sample" :
220222 link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
221223 else :
222- url = self .__get_download_url (model_format )
223- response = requests .get (url , params = {"api_key" : self .__api_key })
224- if response .status_code == 200 :
225- link = response .json ()["export" ]["link" ]
226- else :
227- try :
228- raise RuntimeError (response .json ())
229- except json .JSONDecodeError :
230- response .raise_for_status ()
224+ workspace , project , * _ = self .id .rsplit ("/" )
225+ try :
226+ export_info = rfapi .get_version_export (
227+ api_key = self .__api_key ,
228+ workspace_url = workspace ,
229+ project_url = project ,
230+ version = self .version ,
231+ format = model_format ,
232+ )
233+ except rfapi .RoboflowError as e :
234+ raise RuntimeError (str (e ))
235+
236+ if "ready" in export_info and export_info .get ("ready" ) is False :
237+ raise RuntimeError (export_info )
238+
239+ link = export_info ["export" ]["link" ]
231240
232241 self .__download_zip (link , location , model_format )
233242 self .__extract_zip (location , model_format )
@@ -256,39 +265,36 @@ def export(self, model_format=None):
256265
257266 self .__wait_if_generating ()
258267
259- url = self .__get_download_url ( model_format )
260- response = requests . get ( url , params = { "api_key" : self . __api_key })
261- if not response . ok :
262- try :
263- raise RuntimeError ( response . json ())
264- except json . JSONDecodeError :
265- response . raise_for_status ()
266-
267- # the rest api returns 202 if the export is still in progress
268- if response . status_code == 202 :
269- status_code_check = 202
270- while status_code_check == 202 :
271- time . sleep ( 1 )
272- response = requests . get ( url , params = { "api_key" : self . __api_key } )
273- status_code_check = response . status_code
274- if status_code_check == 202 :
275- progress = response . json ()[ "progress" ]
276- progress_message = (
277- "Exporting format " + model_format + " in progress : " + str ( round ( progress * 100 , 2 )) + "%"
278- )
279- sys . stdout . write ( " \r " + progress_message )
280- sys . stdout . flush ()
281-
282- if response . status_code == 200 :
268+ workspace , project , * _ = self .id . rsplit ( "/" )
269+ export_info = rfapi . get_version_export (
270+ api_key = self . __api_key ,
271+ workspace_url = workspace ,
272+ project_url = project ,
273+ version = self . version ,
274+ format = model_format ,
275+ )
276+ while "ready" in export_info and export_info . get ( "ready" ) is False :
277+ progress = export_info . get ( "progress" , 0.0 )
278+ progress_message = (
279+ "Exporting format " + model_format + " in progress : " + str ( round ( progress * 100 , 2 )) + "%"
280+ )
281+ sys . stdout . write ( " \r " + progress_message )
282+ sys . stdout . flush ()
283+ time . sleep ( 1 )
284+ export_info = rfapi . get_version_export (
285+ api_key = self . __api_key ,
286+ workspace_url = workspace ,
287+ project_url = project ,
288+ version = self . version ,
289+ format = model_format ,
290+ )
291+ if "export" in export_info :
283292 sys .stdout .write ("\n " )
284293 print ("\r " + "Version export complete for " + model_format + " format" )
285294 sys .stdout .flush ()
286295 return True
287296 else :
288- try :
289- raise RuntimeError (response .json ())
290- except json .JSONDecodeError :
291- response .raise_for_status ()
297+ raise RuntimeError (f"Unexpected export { export_info } " )
292298
293299 def train (self , speed = None , model_type = None , checkpoint = None , plot_in_notebook = False ) -> InferenceModel :
294300 """
@@ -326,28 +332,22 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
326332 self .export (train_model_format )
327333
328334 workspace , project , * _ = self .id .rsplit ("/" )
329- url = f"{ API_URL } /{ workspace } /{ project } /{ self .version } /train"
330335
331- data = {}
332-
333- if speed :
334- data ["speed" ] = speed
335-
336- if checkpoint :
337- data ["checkpoint" ] = checkpoint
338-
339- if model_type :
340- # API expects camelCase key
341- data ["modelType" ] = model_type
336+ payload_speed = speed if speed else None
337+ payload_checkpoint = checkpoint if checkpoint else None
338+ payload_model_type = model_type if model_type else None
342339
343340 write_line ("Reaching out to Roboflow to start training..." )
344341
345- response = requests .post (url , json = data , params = {"api_key" : self .__api_key })
346- if not response .ok :
347- try :
348- raise RuntimeError (response .json ())
349- except json .JSONDecodeError :
350- response .raise_for_status ()
342+ rfapi .start_version_training (
343+ api_key = self .__api_key ,
344+ workspace_url = workspace ,
345+ project_url = project ,
346+ version = self .version ,
347+ speed = payload_speed ,
348+ checkpoint = payload_checkpoint ,
349+ model_type = payload_model_type ,
350+ )
351351
352352 status = "training"
353353
@@ -374,10 +374,14 @@ def live_plot(epochs, mAP, loss, title=""):
374374 num_machine_spin_dots = []
375375
376376 while status == "training" or status == "running" :
377- url = f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } ?nocache=true"
378- response = requests .get (url , params = {"api_key" : self .__api_key })
379- response .raise_for_status ()
380- version = response .json ()["version" ]
377+ version_response = rfapi .get_version (
378+ api_key = self .__api_key ,
379+ workspace_url = self .workspace ,
380+ project_url = self .project ,
381+ version = self .version ,
382+ nocache = True ,
383+ )
384+ version = version_response .get ("version" , {})
381385 if "models" in version .keys ():
382386 models = version ["models" ]
383387 else :
0 commit comments