11import datetime
2- from typing import TYPE_CHECKING , Optional
2+ from enum import Enum
3+ from typing import TYPE_CHECKING , Optional , Union , List
34import logging
45
56from labelbox .alignerr .schema .project_rate import BillingMode
67from labelbox .alignerr .schema .project_rate import ProjectRateInput
78from labelbox .alignerr .schema .project_domain import ProjectDomain
89from labelbox .alignerr .schema .enchanced_resource_tags import EnhancedResourceTag , ResourceTagType
10+ from labelbox .alignerr .schema .project_boost_workforce import ProjectBoostWorkforce
911from labelbox .schema .media_type import MediaType
1012
1113logger = logging .getLogger (__name__ )
1214
1315
16+ class ValidationType (Enum ):
17+ """Enum for validation types that can be selectively skipped."""
18+ ALIGNERR_RATE = "AlignerrRate"
19+ CUSTOMER_RATE = "CustomerRate"
20+ PROJECT_OWNER = "ProjectOwner"
21+
22+
1423if TYPE_CHECKING :
1524 from labelbox import Client
1625 from labelbox .alignerr .alignerr_project import AlignerrProject , AlignerrRole
@@ -23,6 +32,7 @@ def __init__(self, client: "Client"):
2332 self ._customer_rate : ProjectRateInput = None
2433 self ._domains : list [ProjectDomain ] = []
2534 self ._enhanced_resource_tags : list [EnhancedResourceTag ] = []
35+ self ._project_owner_email : Optional [str ] = None
2636 self .role_name_to_id = self ._get_role_name_to_id ()
2737
2838 def set_name (self , name : str ):
@@ -141,10 +151,24 @@ def set_tags(self, tag_texts: list[str], tag_type: ResourceTagType):
141151 self ._enhanced_resource_tags .append (new_tag )
142152 return self
143153
154+ def set_project_owner (self , project_owner_email : str ):
155+ """Set the project owner for the ProjectBoostWorkforce.
156+
157+ Args:
158+ project_owner_email: Email of the user to set as project owner
159+
160+ Returns:
161+ Self for method chaining
162+ """
163+ self ._project_owner_email = project_owner_email
164+ return self
144165
145- def create (self , skip_validation : bool = False ):
166+
167+ def create (self , skip_validation : Union [bool , List [ValidationType ]] = False ):
146168 if not skip_validation :
147169 self ._validate ()
170+ elif isinstance (skip_validation , list ):
171+ self ._validate_selective (skip_validation )
148172 logger .info ("Creating project" )
149173
150174 project_data = {
@@ -163,6 +187,7 @@ def create(self, skip_validation: bool = False):
163187 self ._create_rates (alignerr_project )
164188 self ._create_domains (alignerr_project )
165189 self ._create_resource_tags (alignerr_project )
190+ self ._create_project_owner (alignerr_project )
166191
167192 return alignerr_project
168193
@@ -202,6 +227,22 @@ def _create_resource_tags(self, alignerr_project: "AlignerrProject"):
202227 tag_type_enum = ResourceTagType (tag_type_str )
203228 alignerr_project .set_tags (tag_names , tag_type_enum )
204229
230+ def _create_project_owner (self , alignerr_project : "AlignerrProject" ):
231+ if self ._project_owner_email :
232+ logger .info (f"Setting project owner: { self ._project_owner_email } " )
233+
234+ # Find user by email in the organization
235+ user_id = self ._find_user_by_email (self ._project_owner_email )
236+ if not user_id :
237+ current_org = self .client .get_organization ()
238+ raise ValueError (f"User with email { self ._project_owner_email } not found in organization { current_org .uid } " )
239+
240+ ProjectBoostWorkforce .set_project_owner (
241+ client = self .client ,
242+ project_id = alignerr_project .project .uid ,
243+ project_owner_user_id = user_id
244+ )
245+
205246 def _validate_alignerr_rates (self ):
206247 # Import here to avoid circular imports
207248 from labelbox .alignerr .alignerr_project import AlignerrRole
@@ -221,10 +262,61 @@ def _validate_customer_rate(self):
221262 if self ._customer_rate is None :
222263 raise ValueError ("Customer rate is not set" )
223264
265+ def _validate_project_owner (self ):
266+ if self ._project_owner_email is None :
267+ raise ValueError ("Project owner is not set" )
268+
224269 def _validate (self ):
225270 self ._validate_alignerr_rates ()
226271 self ._validate_customer_rate ()
272+ self ._validate_project_owner ()
273+
274+ def _validate_selective (self , skip_validations : List [ValidationType ]):
275+ """Run validations selectively, skipping those in the provided list.
276+
277+ Args:
278+ skip_validations: List of ValidationType enums to skip
279+ """
280+ if ValidationType .ALIGNERR_RATE not in skip_validations :
281+ self ._validate_alignerr_rates ()
282+
283+ if ValidationType .CUSTOMER_RATE not in skip_validations :
284+ self ._validate_customer_rate ()
285+
286+ if ValidationType .PROJECT_OWNER not in skip_validations :
287+ self ._validate_project_owner ()
227288
228289 def _get_role_name_to_id (self ) -> dict [str , str ]:
229290 roles = self .client .get_roles ()
230291 return {role .name : role .uid for role in roles .values ()}
292+
293+ def _find_user_by_email (self , email : str ) -> Optional [str ]:
294+ """Find user ID by email in the organization.
295+
296+ Args:
297+ email: Email address to search for
298+
299+ Returns:
300+ User ID if found, None otherwise
301+ """
302+ try :
303+ # Import here to avoid circular imports
304+ from labelbox .schema .user import User
305+
306+ # Get the current organization
307+ current_org = self .client .get_organization ()
308+
309+ # Use client.get_users with where clause to find user by email
310+ users = self .client .get_users (where = User .email == email )
311+
312+ # Get the first matching user and verify they belong to the same organization
313+ user = next (users , None )
314+ if user and user .organization ().uid == current_org .uid :
315+ return user .uid
316+ else :
317+ logger .warning (f"User with email { email } not found in organization { current_org .uid } " )
318+ return None
319+
320+ except Exception as e :
321+ logger .error (f"Error finding user by email { email } : { e } " )
322+ return None
0 commit comments