Skip to content

Commit f5c3983

Browse files
Add ability to configure model training epochs
1 parent 4310e56 commit f5c3983

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.2.10"
18+
__version__ = "1.2.11"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/adapters/rfapi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def start_version_training(
5858
speed: Optional[str] = None,
5959
checkpoint: Optional[str] = None,
6060
model_type: Optional[str] = None,
61+
epochs: Optional[int] = None,
6162
):
6263
"""
6364
Start a training job for a specific version.
@@ -74,6 +75,8 @@ def start_version_training(
7475
if model_type is not None:
7576
# API expects camelCase
7677
data["modelType"] = model_type
78+
if epochs is not None:
79+
data["epochs"] = epochs
7780

7881
response = requests.post(url, json=data)
7982
if not response.ok:

roboflow/core/version.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,15 @@ def export(self, model_format=None) -> bool | None:
296296
else:
297297
raise RuntimeError(f"Unexpected export {export_info}")
298298

299-
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
299+
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False, epochs=None) -> InferenceModel:
300300
"""
301301
Ask the Roboflow API to train a previously exported version's dataset.
302302
303303
Args:
304304
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
305305
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
306306
checkpoint: A string representing the checkpoint to use while training
307+
epochs: Number of epochs to train the model
307308
plot: Whether to plot the training results. Default is `False`.
308309
309310
Returns:
@@ -336,6 +337,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
336337
speed=payload_speed,
337338
checkpoint=payload_checkpoint,
338339
model_type=payload_model_type,
340+
epochs=epochs,
339341
)
340342

341343
status = "training"

0 commit comments

Comments
 (0)