@@ -367,6 +367,34 @@ def create_prediction(self, label, data_row, prediction_model=None):
367367 res = self .client .execute (query_str , params )
368368 return Prediction (self .client , res ["createPrediction" ])
369369
370+ def enable_model_assisted_labeling (self , toggle : bool = True ) -> bool :
371+ """ Turns model assisted labeling either on or off based on input
372+
373+ Args:
374+ toggle (Boolean): True or False boolean
375+ Returns:
376+ True if toggled on or False if toggled off
377+ """
378+
379+ project_param = "project_id"
380+ show_param = "show"
381+
382+ query_str = """mutation toggle_model_assisted_labelingPyApi($%s: ID!, $%s: Boolean!) {
383+ project(where: {id: $%s }) {
384+ showPredictionsToLabelers(show: $%s) {
385+ id, showingPredictionsToLabelers
386+ }
387+ }
388+ }""" % (project_param , show_param ,project_param , show_param )
389+
390+ params = {
391+ project_param : self .uid ,
392+ show_param : toggle
393+ }
394+
395+ res = self .client .execute (query_str , params )
396+ return res ["project" ]["showPredictionsToLabelers" ]["showingPredictionsToLabelers" ]
397+
370398 def upload_annotations (
371399 self ,
372400 name : str ,
@@ -432,7 +460,6 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
432460 raise ValueError (
433461 f'Invalid annotations given of type: { type (annotations )} ' )
434462
435-
436463class LabelingParameterOverride (DbObject ):
437464 priority = Field .Int ("priority" )
438465 number_of_labels = Field .Int ("number_of_labels" )
0 commit comments