Skip to content

Commit 335df57

Browse files
committed
Fixing format export for CLI/python sdk
1 parent 35e775c commit 335df57

File tree

4 files changed

+67
-20
lines changed

4 files changed

+67
-20
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.2.4"
18+
__version__ = "1.2.6"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from roboflow.util.annotations import amend_data_yaml
3434
from roboflow.util.general import write_line
3535
from roboflow.util.model_processor import process
36-
from roboflow.util.versions import get_wrong_dependencies_versions, normalize_yolo_model_type
36+
from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions, normalize_yolo_model_type
3737

3838
if TYPE_CHECKING:
3939
import numpy as np
@@ -235,7 +235,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
235235

236236
return Dataset(self.name, self.version, model_format, os.path.abspath(location))
237237

238-
def export(self, model_format=None):
238+
def export(self, model_format=None) -> str:
239239
"""
240240
Ask the Roboflow API to generate a version's dataset in a given format so that it can be downloaded via the `download()` method.
241241
@@ -245,7 +245,7 @@ def export(self, model_format=None):
245245
model_format (str): A format to use for downloading
246246
247247
Returns:
248-
True
248+
The URL of the exported dataset.
249249
250250
Raises:
251251
RuntimeError: If the Roboflow API returns an error with a helpful JSON body
@@ -283,7 +283,7 @@ def export(self, model_format=None):
283283
sys.stdout.write("\n")
284284
print("\r" + "Version export complete for " + model_format + " format")
285285
sys.stdout.flush()
286-
return True
286+
return url
287287
else:
288288
try:
289289
raise RuntimeError(response.json())
@@ -310,26 +310,17 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
310310

311311
self.__wait_if_generating()
312312

313-
train_model_format = "yolov5pytorch"
314-
315-
if self.type == TYPE_CLASSICATION:
316-
train_model_format = "folder"
317-
318-
if self.type == TYPE_INSTANCE_SEGMENTATION:
319-
train_model_format = "yolov5pytorch"
320-
321-
if self.type == TYPE_SEMANTIC_SEGMENTATION:
322-
train_model_format = "png-mask-semantic"
323-
324-
# if classification
325-
if train_model_format not in self.exports:
326-
self.export(train_model_format)
313+
train_model_format = get_model_format(model_type)
327314

328315
workspace, project, *_ = self.id.rsplit("/")
329316
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
317+
link = self.export(train_model_format)
330318

331319
data = {}
332320

321+
if link:
322+
data["link"] = link
323+
333324
if speed:
334325
data["speed"] = speed
335326

@@ -341,6 +332,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
341332
data["modelType"] = model_type
342333

343334
write_line("Reaching out to Roboflow to start training...")
335+
print(data)
344336

345337
response = requests.post(url, json=data, params={"api_key": self.__api_key})
346338
if not response.ok:

roboflow/util/versions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,40 @@ def normalize_yolo_model_type(model_type: str) -> str:
9595
model_type = model_type.replace("yolo11", "yolov11")
9696
model_type = model_type.replace("yolo12", "yolov12")
9797
return model_type
98+
99+
100+
def get_model_format(model_type: str) -> str:
101+
"""
102+
Get the model format for a given model type.
103+
Args:
104+
model_type (str): The model type to get the format for.
105+
106+
Returns:
107+
str: The model format.
108+
109+
Example:
110+
>>> get_model_format("yolov5v6n")
111+
"yolov5pytorch"
112+
>>> get_model_format("rfdetr-nano")
113+
"coco"
114+
>>> get_model_format("yolov11n")
115+
"yolov5pytorch"
116+
"""
117+
# Prefixes extrated from modelRegistry.js in roboflow.
118+
model_formats = {
119+
"yolo": "yolov5pytorch",
120+
"pali": "jsonl",
121+
"flor": "jsonl",
122+
"qwen": "jsonl",
123+
"smol": "jsonl",
124+
"vit-b": "folder",
125+
"resn": "folder",
126+
"rfdetr": "coco",
127+
"rf-detr": "coco",
128+
"deep": "png-mask-semantic",
129+
}
130+
131+
for prefix, format in model_formats.items():
132+
if prefix in model_type:
133+
return format
134+
return "yolov5pytorch"

tests/util/test_versions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from importlib import import_module
33

4-
from roboflow.util.versions import get_wrong_dependencies_versions
4+
from roboflow.util.versions import get_wrong_dependencies_versions, get_model_format
55

66

77
class TestVersions(unittest.TestCase):
@@ -23,3 +23,21 @@ def test_wrong_dependencies_versions(self):
2323
wrong_dependencies_versions = get_wrong_dependencies_versions([test])
2424
is_correct_dep = len(wrong_dependencies_versions) == 0
2525
self.assertEqual(is_correct_dep, expected_result)
26+
27+
28+
class TestGetModelFormat(unittest.TestCase):
29+
def test_get_model_format_with_various_ids(self):
30+
cases = [
31+
("yolov5v2s", "yolov5pytorch"),
32+
("yolov11n", "yolov5pytorch"),
33+
("rf-detr-nas-parent", "coco"),
34+
("rfdetr-nano", "coco"),
35+
("vit-base-patch16-224-in21k", "folder"),
36+
("resnet14", "folder"),
37+
("resenet38", "yolov5pytorch"),
38+
("invlid-type", "yolov5pytorch"),
39+
]
40+
41+
for model_type, expected_format in cases:
42+
with self.subTest(model_type=model_type):
43+
self.assertEqual(get_model_format(model_type), expected_format)

0 commit comments

Comments
 (0)