3333from roboflow .util .annotations import amend_data_yaml
3434from roboflow .util .general import write_line
3535from roboflow .util .model_processor import process
36- from roboflow .util .versions import get_wrong_dependencies_versions , normalize_yolo_model_type
36+ from roboflow .util .versions import get_model_format , get_wrong_dependencies_versions , normalize_yolo_model_type
3737
3838if TYPE_CHECKING :
3939 import numpy as np
@@ -235,7 +235,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
235235
236236 return Dataset (self .name , self .version , model_format , os .path .abspath (location ))
237237
238- def export (self , model_format = None ):
238+ def export (self , model_format = None ) -> str :
239239 """
240240 Ask the Roboflow API to generate a version's dataset in a given format so that it can be downloaded via the `download()` method.
241241
@@ -245,7 +245,7 @@ def export(self, model_format=None):
245245 model_format (str): A format to use for downloading
246246
247247 Returns:
248- True
248+ The URL of the exported dataset.
249249
250250 Raises:
251251 RuntimeError: If the Roboflow API returns an error with a helpful JSON body
@@ -283,7 +283,7 @@ def export(self, model_format=None):
283283 sys .stdout .write ("\n " )
284284 print ("\r " + "Version export complete for " + model_format + " format" )
285285 sys .stdout .flush ()
286- return True
286+ return url
287287 else :
288288 try :
289289 raise RuntimeError (response .json ())
@@ -310,26 +310,17 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
310310
311311 self .__wait_if_generating ()
312312
313- train_model_format = "yolov5pytorch"
314-
315- if self .type == TYPE_CLASSICATION :
316- train_model_format = "folder"
317-
318- if self .type == TYPE_INSTANCE_SEGMENTATION :
319- train_model_format = "yolov5pytorch"
320-
321- if self .type == TYPE_SEMANTIC_SEGMENTATION :
322- train_model_format = "png-mask-semantic"
323-
324- # if classification
325- if train_model_format not in self .exports :
326- self .export (train_model_format )
313+ train_model_format = get_model_format (model_type )
327314
328315 workspace , project , * _ = self .id .rsplit ("/" )
329316 url = f"{ API_URL } /{ workspace } /{ project } /{ self .version } /train"
317+ link = self .export (train_model_format )
330318
331319 data = {}
332320
321+ if link :
322+ data ["link" ] = link
323+
333324 if speed :
334325 data ["speed" ] = speed
335326
@@ -341,6 +332,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
341332 data ["modelType" ] = model_type
342333
343334 write_line ("Reaching out to Roboflow to start training..." )
335+ print (data )
344336
345337 response = requests .post (url , json = data , params = {"api_key" : self .__api_key })
346338 if not response .ok :
0 commit comments