@@ -40,6 +40,9 @@ class Project(DbObject, Updateable, Deletable):
4040 "LabelingParameterOverride" , False , "labeling_parameter_overrides" )
4141 webhooks = Relationship .ToMany ("Webhook" , False )
4242 benchmarks = Relationship .ToMany ("Benchmark" , False )
43+ active_prediction_model = Relationship .ToOne ("PredictionModel" , False ,
44+ "active_prediction_model" )
45+ predictions = Relationship .ToMany ("Prediction" , False )
4346
4447 def create_label (self , ** kwargs ):
4548 """ Creates a label on this Project.
@@ -283,6 +286,59 @@ def extend_reservations(self, queue_type):
283286 res = self .client .execute (query_str , {id_param : self .uid })
284287 return res ["extendReservations" ]
285288
289+ def create_prediction_model (self , name , version ):
290+ """ Creates a PredictionModel connected to this Project.
291+ Args:
292+ name (str): The new PredictionModel's name.
293+ version (int): The new PredictionModel's version.
294+ Return:
295+ A newly created PredictionModel.
296+ """
297+ PM = Entity .PredictionModel
298+ model = self .client ._create (
299+ PM , {PM .name .name : name , PM .version .name : version })
300+ self .active_prediction_model .connect (model )
301+ return model
302+
303+ def create_prediction (self , label , data_row , prediction_model = None ):
304+ """ Creates a Prediction within this Project.
305+ Args:
306+ label (str): The `label` field of the new Prediction.
307+ data_row (DataRow): The DataRow for which the Prediction is created.
308+ prediction_model (PredictionModel or None): The PredictionModel
309+ within which the new Prediction is created. If None then this
310+ Project's active_prediction_model is used.
311+ Return:
312+ A newly created Prediction.
313+ Raises:
314+ labelbox.excepions.InvalidQueryError: if given `prediction_model`
315+ is None and this Project's active_prediction_model is also
316+ None.
317+ """
318+ if prediction_model is None :
319+ prediction_model = self .active_prediction_model ()
320+ if prediction_model is None :
321+ raise InvalidQueryError (
322+ "Project '%s' has no active prediction model" % self .name )
323+
324+ label_param = "label"
325+ model_param = "prediction_model_id"
326+ project_param = "project_id"
327+ data_row_param = "data_row_id"
328+
329+ Prediction = Entity .Prediction
330+ query_str = """mutation CreatePredictionPyApi(
331+ $%s: String!, $%s: ID!, $%s: ID!, $%s: ID!) {createPrediction(
332+ data: {label: $%s, predictionModelId: $%s, projectId: $%s,
333+ dataRowId: $%s})
334+ {%s}}""" % (label_param , model_param , project_param , data_row_param ,
335+ label_param , model_param , project_param , data_row_param ,
336+ query .results_query_part (Prediction ))
337+ params = {label_param : label , model_param : prediction_model .uid ,
338+ data_row_param : data_row .uid , project_param : self .uid }
339+ res = self .client .execute (query_str , params )
340+ return Prediction (self .client , res ["createPrediction" ])
341+
286342
287343class LabelingParameterOverride (DbObject ):
288344 priority = Field .Int ("priority" )
0 commit comments