1- from enum import Enum
2- from labelbox .schema .enums import AnnotationImportState , ImportType
3- from typing import Any , Dict , List
1+ from typing import Any , Dict , List , Union
42import functools
53import os
64import json
1210import requests
1311
1412import labelbox
13+ from labelbox .schema .enums import AnnotationImportState
1514from labelbox .orm .db_object import DbObject
1615from labelbox .orm .model import Field , Relationship
1716from labelbox .orm import query
2120
2221
2322class AnnotationImport (DbObject ):
24- # This class will replace BulkImportRequest.
25- # Currently this exists for the MEA beta.
26- # Use BulkImportRequest for now if you are not using MEA.
27-
28- id_name : str
29- import_type : ImportType
30-
3123 name = Field .String ("name" )
3224 state = Field .Enum (AnnotationImportState , "state" )
3325 input_file_url = Field .String ("input_file_url" )
@@ -36,6 +28,10 @@ class AnnotationImport(DbObject):
3628
3729 created_by = Relationship .ToOne ("User" , False , "created_by" )
3830
31+ parent_id : str
32+ _mutation : str
33+ _parent_id_field : str
34+
3935 @property
4036 def inputs (self ) -> List [Dict [str , Any ]]:
4137 """
@@ -123,20 +119,12 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
123119 return ndjson .loads (response .text )
124120
125121 @classmethod
126- def _build_import_predictions_query (cls , file_args : str , vars : str ):
127- raise NotImplementedError ("" )
128-
129- @classmethod
130- def validate_cls (cls ):
131- supported_base_classes = {MALPredictionImport , MEAPredictionImport }
132- if cls not in {MALPredictionImport , MEAPredictionImport }:
133- raise TypeError (
134- f"Can't directly use the base AnnotationImport class. Must use one of { supported_base_classes } "
135- )
136-
137- @classmethod
138- def from_name (cls , client , parent_id , name : str , raw = False ):
139- cls .validate_cls ()
122+ def _from_name (cls ,
123+ client : "labelbox.Client" ,
124+ parent_id : str ,
125+ name : str ,
126+ raw = False
127+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
140128 query_str = """query getImportPyApi($parent_id : ID!, $name: String!) {
141129 annotationImport(
142130 where: {%s: $parent_id, name: $name}){
@@ -145,7 +133,7 @@ def from_name(cls, client, parent_id, name: str, raw=False):
145133 ... on ModelErrorAnalysisPredictionImport {%s}
146134 }}""" % \
147135 (
148- cls .id_name ,
136+ cls ._parent_id_field ,
149137 query .results_query_part (MALPredictionImport ),
150138 query .results_query_part (MEAPredictionImport )
151139 )
@@ -159,19 +147,6 @@ def from_name(cls, client, parent_id, name: str, raw=False):
159147
160148 return cls (client , response ['annotationImport' ])
161149
162- @classmethod
163- def _create_from_url (cls , client , parent_id , name , url ):
164- file_args = "fileUrl : $fileUrl"
165- query_str = cls ._build_import_predictions_query (file_args ,
166- "$fileUrl: String!" )
167- response = client .execute (query_str ,
168- params = {
169- "fileUrl" : url ,
170- "parent_id" : parent_id ,
171- 'name' : name
172- })
173- return cls (client , response ['createAnnotationImport' ])
174-
175150 @staticmethod
176151 def _make_file_name (parent_id : str , name : str ) -> str :
177152 return f"{ parent_id } __{ name } .ndjson"
@@ -180,131 +155,160 @@ def refresh(self) -> None:
180155 """Synchronizes values of all fields with the database.
181156 """
182157 cls = type (self )
183- res = cls .from_name (self .client ,
184- self .get_parent_id (),
185- self .name ,
186- raw = True )
158+ res = cls ._from_name (self .client , self .parent_id , self .name , raw = True )
187159 self ._set_field_values (res )
188160
189161 @classmethod
190- def _create_from_bytes (cls , client , parent_id , name , bytes_data ,
191- content_len ):
162+ def _create_from_bytes (
163+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
164+ bytes_data : bytes , content_len : int
165+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
192166 file_name = cls ._make_file_name (parent_id , name )
193- file_args = """filePayload: {
194- file: $file,
195- contentLength: $contentLength
196- }"""
197- query_str = cls ._build_import_predictions_query (
198- file_args , "$file: Upload!, $contentLength: Int!" )
199167 variables = {
200168 "file" : None ,
201169 "contentLength" : content_len ,
202- "parent_id " : parent_id ,
170+ "parentId " : parent_id ,
203171 "name" : name
204172 }
173+ query_str = cls ._get_file_mutation ()
205174 operations = json .dumps ({"variables" : variables , "query" : query_str })
206175 data = {
207176 "operations" : operations ,
208177 "map" : (None , json .dumps ({file_name : ["variables.file" ]}))
209178 }
210179 file_data = (file_name , bytes_data , NDJSON_MIME_TYPE )
211180 files = {file_name : file_data }
212-
213- print (data )
214- breakpoint ()
215- return client .execute (data = data , files = files )
181+ return cls (client ,
182+ client .execute (data = data , files = files )[cls ._mutation ])
216183
217184 @classmethod
218- def _create_from_objects (cls , client , parent_id , name , predictions ):
185+ def _create_from_objects (
186+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
187+ predictions : List [Dict [str , Any ]]
188+ ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
219189 data_str = ndjson .dumps (predictions )
220190 if not data_str :
221191 raise ValueError ('annotations cannot be empty' )
222192 data = data_str .encode ('utf-8' )
223193 return cls ._create_from_bytes (client , parent_id , name , data , len (data ))
224194
225195 @classmethod
226- def _create_from_file (cls , client , parent_id , name , path ):
196+ def _create_from_url (
197+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
198+ url : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
199+ if requests .head (url ):
200+ query_str = cls ._get_url_mutation ()
201+ return cls (
202+ client ,
203+ client .execute (query_str ,
204+ params = {
205+ "fileUrl" : url ,
206+ "parentId" : parent_id ,
207+ 'name' : name
208+ })[cls ._mutation ])
209+ else :
210+ raise ValueError (f"Url { url } is not reachable" )
211+
212+ @classmethod
213+ def _create_from_file (
214+ cls , client : "labelbox.Client" , parent_id : str , name : str ,
215+ path : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
227216 if os .path .exists (path ):
228217 with open (path , 'rb' ) as f :
229218 return cls ._create_from_bytes (client , parent_id , name , f ,
230219 os .stat (path ).st_size )
231- elif requests .head (path ):
232- return cls ._create_from_url (client , parent_id , name , path )
233- raise ValueError (
234- f"Path { path } is not accessible locally or on a remote server" )
235-
236- def create_from_objects (* args , ** kwargs ):
237- raise NotImplementedError ("" )
220+ else :
221+ raise ValueError (f"File { path } is not accessible" )
238222
239- def create_from_file (* args , ** kwargs ):
240- raise NotImplementedError ("" )
223+ @classmethod
224+ def _get_url_mutation (cls ) -> str :
225+ return """mutation create%sPyApi($parentId : ID!, $name: String!, $fileUrl: String!) {
226+ %s(data: {
227+ %s: $parentId
228+ name: $name
229+ fileUrl: $fileUrl
230+ }) {%s}
231+ }""" % (cls .__class__ .__name__ , cls ._mutation , cls ._parent_id_field ,
232+ query .results_query_part (cls ))
241233
242- def get_parent_id (* args , ** kwargs ):
243- raise NotImplementedError ("" )
234+ @classmethod
235+ def _get_file_mutation (cls ) -> str :
236+ return """mutation create%sPyApi($parentId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
237+ %s(data: { %s : $parentId name: $name filePayload: { file: $file, contentLength: $contentLength}
238+ }) {%s}
239+ }""" % (cls .__class__ .__name__ , cls ._mutation , cls ._parent_id_field ,
240+ query .results_query_part (cls ))
244241
245242
246243class MEAPredictionImport (AnnotationImport ):
247- id_name = "modelRunId"
248- import_type = ImportType .MODEL_ERROR_ANALYSIS
249244 model_run_id = Field .String ("model_run_id" )
245+ _mutation = "createModelErrorAnalysisPredictionImport"
246+ _parent_id_field = "modelRunId"
250247
251- def get_parent_id (self ):
248+ @property
249+ def parent_id (self ) -> str :
252250 return self .model_run_id
253251
254252 @classmethod
255- def create_from_file (cls , client , model_run_id , name , path ):
256- breakpoint ()
257- return cls ( client , cls ._create_from_file (client = client ,
253+ def create_from_file (cls , client : "labelbox.Client" , model_run_id : str ,
254+ name : str , path : str ) -> "MEAPredictionImport" :
255+ return cls ._create_from_file (client = client ,
258256 parent_id = model_run_id ,
259257 name = name ,
260- path = path )[ 'createModelErrorAnalysisPredictionImport' ])
258+ path = path )
261259
262260 @classmethod
263- def create_from_objects (cls , client , model_run_id , name , predictions ):
264- return cls (client , cls ._create_from_objects (client , model_run_id , name , predictions )['createModelErrorAnalysisPredictionImport' ])
261+ def create_from_objects (cls , client : "labelbox.Client" , model_run_id : str ,
262+ name , predictions ) -> "MEAPredictionImport" :
263+ return cls ._create_from_objects (client , model_run_id , name , predictions )
265264
266265 @classmethod
267- def _build_import_predictions_query (cls , file_args : str , vars : str ):
268- query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
269- createModelErrorAnalysisPredictionImport(data: {
270- %s : $parent_id
271- name: $name
272- %s
273- }) {%s}
274- }""" % (vars , cls .id_name , file_args ,query .results_query_part (cls ))
275- return query_str
266+ def create_from_url (cls , client : "labelbox.Client" , model_run_id : str ,
267+ name : str , url : str ) -> "MEAPredictionImport" :
268+ return cls ._create_from_url (client = client ,
269+ parent_id = model_run_id ,
270+ name = name ,
271+ url = url )
272+
273+ @classmethod
274+ def from_name (
275+ cls , client : "labelbox.Client" , model_run_id : str ,
276+ name : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
277+ return cls ._from_name (client , model_run_id , name )
276278
277279
278280class MALPredictionImport (AnnotationImport ):
279- id_name = "projectId"
280- import_type = ImportType .MODEL_ASSISTED_LABELING
281281 project = Relationship .ToOne ("Project" , cache = True )
282+ _mutation = "createModelAssistedLabelingPredictionImport"
283+ _parent_id_field = "projectId"
282284
283- def get_parent_id (self ):
285+ @property
286+ def parent_id (self ) -> str :
284287 return self .project ().uid
285288
286289 @classmethod
287- def create_from_file (cls , client , project_id , name , path ):
288- return cls (client , cls ._create_from_file (client = client ,
290+ def create_from_file (cls , client : "labelbox.Client" , project_id : str ,
291+ name : str , path : str ) -> "MALPredictionImport" :
292+ return cls ._create_from_file (client = client ,
289293 parent_id = project_id ,
290294 name = name ,
291- path = path )[ 'createModelAssistedLabelingPredictionImport' ])
295+ path = path )
292296
293297 @classmethod
294- def create_from_objects (cls , client , project_id , name , predictions ):
295- return cls (client , cls ._create_from_objects (client , project_id , name , predictions )['createModelAssistedLabelingPredictionImport' ])
298+ def create_from_objects (cls , client : "labelbox.Client" , project_id : str ,
299+ name , predictions ) -> "MALPredictionImport" :
300+ return cls ._create_from_objects (client , project_id , name , predictions )
296301
297302 @classmethod
298- def _build_import_predictions_query (cls , file_args : str , vars : str ):
299- query_str = """mutation createAnnotationImportPyApi($parent_id : ID!, $name: String!, %s) {
300- createModelAssistedLabelingPredictionImport(data: {
301- %s : $parent_id
302- name: $name
303- %s
304- }) {%s}
305- }""" % (vars , cls .id_name , file_args ,
306- query .results_query_part (cls ))
307- return query_str
308-
309-
303+ def create_from_url (cls , client : "labelbox.Client" , project_id : str ,
304+ name : str , url : str ) -> "MALPredictionImport" :
305+ return cls ._create_from_url (client = client ,
306+ parent_id = project_id ,
307+ name = name ,
308+ url = url )
310309
310+ @classmethod
311+ def from_name (
312+ cls , client : "labelbox.Client" , project_id : str ,
313+ name : str ) -> Union ["MEAPredictionImport" , "MALPredictionImport" ]:
314+ return cls ._from_name (client , project_id , name )
0 commit comments