Skip to content

Commit bc2ce69

Browse files
authored
Adding in function to toggle on and off model assisted labeling (#82)
* Adding in function to toggle on and off model assisted labeling * deleted unnecessary file * updating file
1 parent 8e5da18 commit bc2ce69

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

labelbox/schema/project.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,34 @@ def create_prediction(self, label, data_row, prediction_model=None):
367367
res = self.client.execute(query_str, params)
368368
return Prediction(self.client, res["createPrediction"])
369369

370+
def enable_model_assisted_labeling(self, toggle: bool=True) -> bool:
371+
""" Turns model assisted labeling either on or off based on input
372+
373+
Args:
374+
toggle (Boolean): True or False boolean
375+
Returns:
376+
True if toggled on or False if toggled off
377+
"""
378+
379+
project_param = "project_id"
380+
show_param = "show"
381+
382+
query_str = """mutation toggle_model_assisted_labelingPyApi($%s: ID!, $%s: Boolean!) {
383+
project(where: {id: $%s }) {
384+
showPredictionsToLabelers(show: $%s) {
385+
id, showingPredictionsToLabelers
386+
}
387+
}
388+
}""" % (project_param, show_param,project_param, show_param)
389+
390+
params = {
391+
project_param: self.uid,
392+
show_param: toggle
393+
}
394+
395+
res = self.client.execute(query_str, params)
396+
return res["project"]["showPredictionsToLabelers"]["showingPredictionsToLabelers"]
397+
370398
def upload_annotations(
371399
self,
372400
name: str,
@@ -432,7 +460,6 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
432460
raise ValueError(
433461
f'Invalid annotations given of type: {type(annotations)}')
434462

435-
436463
class LabelingParameterOverride(DbObject):
437464
priority = Field.Int("priority")
438465
number_of_labels = Field.Int("number_of_labels")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from labelbox import Client
2+
from labelbox import Project
3+
import pytest
4+
5+
def test_enable_model_assisted_labeling(project):
6+
response = project.enable_model_assisted_labeling()
7+
assert response == True
8+
9+
response = project.enable_model_assisted_labeling(True)
10+
assert response == True
11+
12+
response = project.enable_model_assisted_labeling(False)
13+
assert response == False

0 commit comments

Comments
 (0)