1111from google .api_core import retry
1212import requests
1313import requests .exceptions
14+ from labelbox .data .annotation_types .feature import FeatureSchema
15+ from labelbox .data .serialization .ndjson .base import DataRow
1416
1517import labelbox .exceptions
1618from labelbox import utils
2022from labelbox .orm .model import Entity
2123from labelbox .pagination import PaginatedCollection
2224from labelbox .schema .data_row_metadata import DataRowMetadataOntology
25+ from labelbox .schema .dataset import Dataset
2326from labelbox .schema .iam_integration import IAMIntegration
2427from labelbox .schema import role
25- from labelbox .schema .ontology import Tool , Classification
28+ from labelbox .schema .labeling_frontend import LabelingFrontend
29+ from labelbox .schema .model import Model
30+ from labelbox .schema .ontology import Ontology , Tool , Classification
31+ from labelbox .schema .organization import Organization
32+ from labelbox .schema .user import User
33+ from labelbox .schema .project import Project
34+ from labelbox .schema .role import Role
2635
2736logger = logging .getLogger (__name__ )
2837
@@ -411,7 +420,7 @@ def get_project(self, project_id):
411420 """
412421 return self ._get_single (Entity .Project , project_id )
413422
414- def get_dataset (self , dataset_id ):
423+ def get_dataset (self , dataset_id ) -> Dataset :
415424 """ Gets a single Dataset with the given ID.
416425
417426 >>> dataset = client.get_dataset("<dataset_id>")
@@ -426,14 +435,14 @@ def get_dataset(self, dataset_id):
426435 """
427436 return self ._get_single (Entity .Dataset , dataset_id )
428437
429- def get_user (self ):
438+ def get_user (self ) -> User :
430439 """ Gets the current User database object.
431440
432441 >>> user = client.get_user()
433442 """
434443 return self ._get_single (Entity .User , None )
435444
436- def get_organization (self ):
445+ def get_organization (self ) -> Organization :
437446 """ Gets the Organization DB object of the current user.
438447
439448 >>> organization = client.get_organization()
@@ -461,7 +470,7 @@ def _get_all(self, db_object_type, where, filter_deleted=True):
461470 [utils .camel_case (db_object_type .type_name ()) + "s" ],
462471 db_object_type )
463472
464- def get_projects (self , where = None ):
473+ def get_projects (self , where = None ) -> List [ Project ] :
465474 """ Fetches all the projects the user has access to.
466475
467476 >>> projects = client.get_projects(where=(Project.name == "<project_name>") & (Project.description == "<project_description>"))
@@ -474,7 +483,7 @@ def get_projects(self, where=None):
474483 """
475484 return self ._get_all (Entity .Project , where )
476485
477- def get_datasets (self , where = None ):
486+ def get_datasets (self , where = None ) -> List [ Dataset ] :
478487 """ Fetches one or more datasets.
479488
480489 >>> datasets = client.get_datasets(where=(Dataset.name == "<dataset_name>") & (Dataset.description == "<dataset_description>"))
@@ -487,7 +496,7 @@ def get_datasets(self, where=None):
487496 """
488497 return self ._get_all (Entity .Dataset , where )
489498
490- def get_labeling_frontends (self , where = None ):
499+ def get_labeling_frontends (self , where = None ) -> List [ LabelingFrontend ] :
491500 """ Fetches all the labeling frontends.
492501
493502 >>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor")
@@ -527,7 +536,9 @@ def _create(self, db_object_type, data):
527536 res = res ["create%s" % db_object_type .type_name ()]
528537 return db_object_type (self , res )
529538
530- def create_dataset (self , iam_integration = IAMIntegration ._DEFAULT , ** kwargs ):
539+ def create_dataset (self ,
540+ iam_integration = IAMIntegration ._DEFAULT ,
541+ ** kwargs ) -> Dataset :
531542 """ Creates a Dataset object on the server.
532543
533544 Attribute values are passed as keyword arguments.
@@ -585,7 +596,7 @@ def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, **kwargs):
585596 raise e
586597 return dataset
587598
588- def create_project (self , ** kwargs ):
599+ def create_project (self , ** kwargs ) -> Project :
589600 """ Creates a Project object on the server.
590601
591602 Attribute values are passed as keyword arguments.
@@ -602,15 +613,15 @@ def create_project(self, **kwargs):
602613 """
603614 return self ._create (Entity .Project , kwargs )
604615
605- def get_roles (self ):
616+ def get_roles (self ) -> List [ Role ] :
606617 """
607618 Returns:
608619 Roles: Provides information on available roles within an organization.
609620 Roles are used for user management.
610621 """
611622 return role .get_roles (self )
612623
613- def get_data_row (self , data_row_id ):
624+ def get_data_row (self , data_row_id ) -> DataRow :
614625 """
615626
616627 Returns:
@@ -619,7 +630,7 @@ def get_data_row(self, data_row_id):
619630
620631 return self ._get_single (Entity .DataRow , data_row_id )
621632
622- def get_data_row_metadata_ontology (self ):
633+ def get_data_row_metadata_ontology (self ) -> DataRowMetadataOntology :
623634 """
624635
625636 Returns:
@@ -628,7 +639,7 @@ def get_data_row_metadata_ontology(self):
628639 """
629640 return DataRowMetadataOntology (self )
630641
631- def get_model (self , model_id ):
642+ def get_model (self , model_id ) -> Model :
632643 """ Gets a single Model with the given ID.
633644
634645 >>> model = client.get_model("<model_id>")
@@ -643,7 +654,7 @@ def get_model(self, model_id):
643654 """
644655 return self ._get_single (Entity .Model , model_id )
645656
646- def get_models (self , where = None ):
657+ def get_models (self , where = None ) -> List [ Model ] :
647658 """ Fetches all the models the user has access to.
648659
649660 >>> models = client.get_models(where=(Model.name == "<model_name>"))
@@ -656,7 +667,7 @@ def get_models(self, where=None):
656667 """
657668 return self ._get_all (Entity .Model , where , filter_deleted = False )
658669
659- def create_model (self , name , ontology_id ):
670+ def create_model (self , name , ontology_id ) -> Model :
660671 """ Creates a Model object on the server.
661672
662673 >>> model = client.create_model(<model_name>, <ontology_id>)
@@ -707,7 +718,7 @@ def get_data_row_ids_for_external_ids(
707718 result [row ['externalId' ]].append (row ['dataRowId' ])
708719 return result
709720
710- def get_ontology (self , ontology_id ):
721+ def get_ontology (self , ontology_id ) -> Ontology :
711722 """
712723 Fetches an Ontology by id.
713724
@@ -718,7 +729,7 @@ def get_ontology(self, ontology_id):
718729 """
719730 return self ._get_single (Entity .Ontology , ontology_id )
720731
721- def get_ontologies (self , name_contains ):
732+ def get_ontologies (self , name_contains ) -> PaginatedCollection :
722733 """
723734 Fetches all ontologies with names that match the name_contains string.
724735
@@ -739,7 +750,7 @@ def get_ontologies(self, name_contains):
739750 ['ontologies' , 'nodes' ], Entity .Ontology ,
740751 ['ontologies' , 'nextCursor' ])
741752
742- def get_feature_schema (self , feature_schema_id ):
753+ def get_feature_schema (self , feature_schema_id ) -> FeatureSchema :
743754 """
744755 Fetches a feature schema. Only supports top level feature schemas.
745756
@@ -760,7 +771,7 @@ def get_feature_schema(self, feature_schema_id):
760771 res ['id' ] = res ['normalized' ]['featureSchemaId' ]
761772 return Entity .FeatureSchema (self , res )
762773
763- def get_feature_schemas (self , name_contains ):
774+ def get_feature_schemas (self , name_contains ) -> PaginatedCollection :
764775 """
765776 Fetches top level feature schemas with names that match the `name_contains` string
766777
@@ -789,7 +800,8 @@ def rootSchemaPayloadToFeatureSchema(client, payload):
789800 rootSchemaPayloadToFeatureSchema ,
790801 ['rootSchemaNodes' , 'nextCursor' ])
791802
792- def create_ontology_from_feature_schemas (self , name , feature_schema_ids ):
803+ def create_ontology_from_feature_schemas (self , name ,
804+ feature_schema_ids ) -> Ontology :
793805 """
794806 Creates an ontology from a list of feature schema ids
795807
@@ -828,7 +840,7 @@ def create_ontology_from_feature_schemas(self, name, feature_schema_ids):
828840 normalized = {'tools' : tools , 'classifications' : classifications }
829841 return self .create_ontology (name , normalized )
830842
831- def create_ontology (self , name , normalized ):
843+ def create_ontology (self , name , normalized ) -> Ontology :
832844 """
833845 Creates an ontology from normalized data
834846 >>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []}
@@ -855,7 +867,7 @@ def create_ontology(self, name, normalized):
855867 res = self .execute (query_str , params )
856868 return Entity .Ontology (self , res ['upsertOntology' ])
857869
858- def create_feature_schema (self , normalized ):
870+ def create_feature_schema (self , normalized ) -> FeatureSchema :
859871 """
860872 Creates a feature schema from normalized data.
861873 >>> normalized = {'tool': 'polygon', 'name': 'cat', 'color': 'black'}
0 commit comments