11from enum import Enum
2- from typing import Set , List , Union , Iterator , Optional
2+ from typing import Set , Iterator
3+ from collections import defaultdict
34
45from labelbox import Client
56from labelbox .exceptions import ResourceCreationError
67from labelbox .pydantic_compat import BaseModel
78from labelbox .schema .user import User
89from labelbox .schema .project import Project
910from labelbox .exceptions import UnprocessableEntityError , InvalidQueryError
11+ from labelbox .schema .queue_mode import QueueMode
12+ from labelbox .schema .ontology_kind import EditorTaskType
13+ from labelbox .schema .media_type import MediaType
1014
1115
1216class UserGroupColor (Enum ):
@@ -34,82 +38,32 @@ class UserGroupColor(Enum):
3438 YELLOW = "E7BF00"
3539 GRAY = "B8C4D3"
3640
37-
38- class UserGroupUser (BaseModel ):
39- """
40- Represents a user in a group.
41-
42- Attributes:
43- id (str): The ID of the user.
44- email (str): The email of the user.
45- """
46- id : str
47- email : str
48-
49- def __hash__ (self ):
50- return hash ((self .id ))
51-
52- def __eq__ (self , other ):
53- if not isinstance (other , UserGroupUser ):
54- return False
55- return self .id == other .id
56-
57-
58- class UserGroupProject (BaseModel ):
59- """
60- Represents a project in a group.
61-
62- Attributes:
63- id (str): The ID of the project.
64- name (str): The name of the project.
65- """
66- id : str
67- name : str
68-
69- def __hash__ (self ):
70- return hash ((self .id ))
71-
72- def __eq__ (self , other ):
73- """
74- Check if this GroupProject object is equal to another GroupProject object.
75-
76- Args:
77- other (GroupProject): The other GroupProject object to compare with.
78-
79- Returns:
80- bool: True if the two GroupProject objects are equal, False otherwise.
81- """
82- if not isinstance (other , UserGroupProject ):
83- return False
84- return self .id == other .id
85-
8641
8742class UserGroup (BaseModel ):
8843 """
8944 Represents a user group in Labelbox.
9045
9146 Attributes:
92- id (Optional[ str] ): The ID of the user group.
93- name (Optional[ str] ): The name of the user group.
47+ id (str): The ID of the user group.
48+ name (str): The name of the user group.
9449 color (UserGroupColor): The color of the user group.
9550 users (Set[UserGroupUser]): The set of users in the user group.
9651 projects (Set[UserGroupProject]): The set of projects associated with the user group.
9752 client (Client): The Labelbox client object.
9853
9954 Methods:
100- __init__(self, client: Client, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE,
101- users: Set[UserGroupUser] = set(), projects: Set[UserGroupProject] = set(), reload=True)
102- _reload(self)
55+ __init__(self, client: Client)
56+ get(self) -> "UserGroup"
10357 update(self) -> "UserGroup"
10458 create(self) -> "UserGroup"
10559 delete(self) -> bool
10660 get_user_groups(client: Client) -> Iterator["UserGroup"]
10761 """
108- id : Optional [ str ]
109- name : Optional [ str ]
62+ id : str
63+ name : str
11064 color : UserGroupColor
111- users : Set [UserGroupUser ]
112- projects : Set [UserGroupProject ]
65+ users : Set [User ]
66+ projects : Set [Project ]
11367 client : Client
11468
11569 class Config :
@@ -122,9 +76,8 @@ def __init__(
12276 id : str = "" ,
12377 name : str = "" ,
12478 color : UserGroupColor = UserGroupColor .BLUE ,
125- users : Set [UserGroupUser ] = set (),
126- projects : Set [UserGroupProject ] = set (),
127- reload = True ,
79+ users : Set [User ] = set (),
80+ projects : Set [Project ] = set ()
12881 ):
12982 """
13083 Initializes a UserGroup object.
@@ -134,36 +87,32 @@ def __init__(
13487 id (str, optional): The ID of the user group. Defaults to an empty string.
13588 name (str, optional): The name of the user group. Defaults to an empty string.
13689 color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE.
137- users (Set[UserGroupUser], optional): The set of users in the user group. Defaults to an empty set.
138- projects (Set[UserGroupProject], optional): The set of projects associated with the user group. Defaults to an empty set.
139- reload (bool, optional): Whether to reload the partial representation of the group. Defaults to True.
90+ users (Set[User], optional): The set of users in the user group. Defaults to an empty set.
91+ projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set.
14092
14193 Raises:
14294 RuntimeError: If the experimental feature is not enabled in the client.
143-
14495 """
14596 super ().__init__ (client = client , id = id , name = name , color = color , users = users , projects = projects )
14697 if not self .client .enable_experimental :
147- raise RuntimeError (
148- "Please enable experimental in client to use UserGroups" )
98+ raise RuntimeError ("Please enable experimental in client to use UserGroups" )
14999
150- # partial representation of the group, reload
151- if self .id and reload :
152- self ._reload ()
153-
154- def _reload (self ):
100+ def get (self ) -> "UserGroup" :
155101 """
156102 Reloads the user group information from the server.
157103
158104 This method sends a GraphQL query to the server to fetch the latest information
159105 about the user group, including its name, color, projects, and members. The fetched
160106 information is then used to update the corresponding attributes of the `Group` object.
161107
162- Raises :
163- InvalidQueryError: If the query fails to fetch the group information .
108+ Args :
109+ id (str): The ID of the user group to fetch.
164110
165111 Returns:
166- None
112+ UserGroup of passed in ID (self)
113+
114+ Raises:
115+ InvalidQueryError: If the query fails to fetch the group information.
167116 """
168117 query = """
169118 query GetUserGroupPyApi($id: ID!) {
@@ -196,14 +145,9 @@ def _reload(self):
196145 raise InvalidQueryError ("Failed to fetch group" )
197146 self .name = result ["userGroup" ]["name" ]
198147 self .color = UserGroupColor (result ["userGroup" ]["color" ])
199- self .projects = {
200- UserGroupProject (id = project ["id" ], name = project ["name" ])
201- for project in result ["userGroup" ]["projects" ]["nodes" ]
202- }
203- self .users = {
204- UserGroupUser (id = member ["id" ], email = member ["email" ])
205- for member in result ["userGroup" ]["members" ]["nodes" ]
206- }
148+ self .projects = self ._get_projects_set (result ["userGroup" ]["projects" ]["nodes" ])
149+ self .users = self ._get_users_set (result ["userGroup" ]["members" ]["nodes" ])
150+ return self
207151
208152 def update (self ) -> "UserGroup" :
209153 """
@@ -249,10 +193,10 @@ def update(self) -> "UserGroup":
249193 "color" :
250194 self .color .value ,
251195 "projectIds" : [
252- project .id for project in self .projects
196+ project .uid for project in self .projects
253197 ],
254198 "userIds" : [
255- user .id for user in self .users
199+ user .uid for user in self .users
256200 ]
257201 }
258202 result = self .client .execute (query , params )
@@ -311,10 +255,10 @@ def create(self) -> "UserGroup":
311255 "color" :
312256 self .color .value ,
313257 "projectIds" : [
314- project .id for project in self .projects
258+ project .uid for project in self .projects
315259 ],
316260 "userIds" : [
317- user .id for user in self .users
261+ user .uid for user in self .users
318262 ]
319263 }
320264 result = self .client .execute (query , params )
@@ -351,8 +295,7 @@ def delete(self) -> bool:
351295 raise UnprocessableEntityError ("Failed to delete user group" )
352296 return result ["deleteUserGroup" ]["success" ]
353297
354- @staticmethod
355- def get_user_groups (client : Client ) -> Iterator ["UserGroup" ]:
298+ def get_user_groups (self ) -> Iterator ["UserGroup" ]:
356299 """
357300 Gets all user groups in Labelbox.
358301
@@ -390,29 +333,60 @@ def get_user_groups(client: Client) -> Iterator["UserGroup"]:
390333 """
391334 nextCursor = None
392335 while True :
393- userGroups = client .execute (
336+ userGroups = self . client .execute (
394337 query , {"nextCursor" : nextCursor })["userGroups" ]
395338 if not userGroups :
396339 return
397340 yield
398341 groups = userGroups ["nodes" ]
399342 for group in groups :
400- yield UserGroup (client ,
401- reload = False ,
402- id = group ["id" ],
403- name = group ["name" ],
404- color = UserGroupColor (group ["color" ]),
405- users = {
406- UserGroupUser (id = member ["id" ],
407- email = member ["email" ])
408- for member in group ["members" ]["nodes" ]
409- },
410- projects = {
411- UserGroupProject (id = project ["id" ],
412- name = project ["name" ])
413- for project in group ["projects" ]["nodes" ]
414- })
343+ userGroup = UserGroup (self .client )
344+ userGroup .id = group ["id" ]
345+ userGroup .name = group ["name" ]
346+ userGroup .color = UserGroupColor (group ["color" ])
347+ userGroup .users = self ._get_users_set (group ["members" ]["nodes" ])
348+ userGroup .projects = self ._get_projects_set (group ["projects" ]["nodes" ])
349+ yield userGroup
415350 nextCursor = userGroups ["nextCursor" ]
416351 # this doesn't seem to be implemented right now to return a value other than null from the api
417352 if not nextCursor :
418353 break
354+
355+ def _get_users_set (self , user_nodes ):
356+ """
357+ Retrieves a set of User objects from the given user nodes.
358+
359+ Args:
360+ user_nodes (list): A list of user nodes containing user information.
361+
362+ Returns:
363+ set: A set of User objects.
364+ """
365+ users = set ()
366+ for user in user_nodes :
367+ user_values = defaultdict (lambda : None )
368+ user_values ["id" ] = user ["id" ]
369+ user_values ["email" ] = user ["email" ]
370+ users .add (User (self .client , user_values ))
371+ return users
372+
373+ def _get_projects_set (self , project_nodes ):
374+ """
375+ Retrieves a set of projects based on the given project nodes.
376+
377+ Args:
378+ project_nodes (list): A list of project nodes.
379+
380+ Returns:
381+ set: A set of Project objects.
382+ """
383+ projects = set ()
384+ for project in project_nodes :
385+ project_values = defaultdict (lambda : None )
386+ project_values ["id" ] = project ["id" ]
387+ project_values ["name" ] = project ["name" ]
388+ project_values ["queueMode" ] = QueueMode .Batch .value
389+ project_values ["editorTaskType" ] = EditorTaskType .Missing .value
390+ project_values ["mediaType" ] = MediaType .Image .value
391+ projects .add (Project (self .client , project_values ))
392+ return projects
0 commit comments