Skip to content

Commit 42ca995

Browse files
committed
Set project owner during project creation
1 parent 815b885 commit 42ca995

File tree

4 files changed

+439
-2
lines changed

4 files changed

+439
-2
lines changed

libs/labelbox/src/labelbox/alignerr/alignerr_project.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from labelbox.alignerr.schema.project_rate import ProjectRateV2
77
from labelbox.alignerr.schema.project_domain import ProjectDomain
88
from labelbox.alignerr.schema.enchanced_resource_tags import EnhancedResourceTag, ResourceTagType
9+
from labelbox.alignerr.schema.project_boost_workforce import ProjectBoostWorkforce
910
from labelbox.pagination import PaginatedCollection
1011

1112
logger = logging.getLogger(__name__)
@@ -132,6 +133,17 @@ def remove_tag(self, tag: EnhancedResourceTag):
132133
self.set_tags(current_tag_names)
133134
return self
134135

136+
def get_project_owner(self) -> Optional[ProjectBoostWorkforce]:
137+
"""Get the ProjectBoostWorkforce for this project.
138+
139+
Returns:
140+
ProjectBoostWorkforce instance or None if not found
141+
"""
142+
return ProjectBoostWorkforce.get_by_project_id(
143+
client=self.client,
144+
project_id=self.project.uid
145+
)
146+
135147

136148
class AlignerrWorkspace:
137149
def __init__(self, client: "Client"):

libs/labelbox/src/labelbox/alignerr/alignerr_project_builder.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
import datetime
2-
from typing import TYPE_CHECKING, Optional
2+
from enum import Enum
3+
from typing import TYPE_CHECKING, Optional, Union, List
34
import logging
45

56
from labelbox.alignerr.schema.project_rate import BillingMode
67
from labelbox.alignerr.schema.project_rate import ProjectRateInput
78
from labelbox.alignerr.schema.project_domain import ProjectDomain
89
from labelbox.alignerr.schema.enchanced_resource_tags import EnhancedResourceTag, ResourceTagType
10+
from labelbox.alignerr.schema.project_boost_workforce import ProjectBoostWorkforce
911
from labelbox.schema.media_type import MediaType
1012

1113
logger = logging.getLogger(__name__)
1214

1315

16+
class ValidationType(Enum):
17+
"""Enum for validation types that can be selectively skipped."""
18+
ALIGNERR_RATE = "AlignerrRate"
19+
CUSTOMER_RATE = "CustomerRate"
20+
PROJECT_OWNER = "ProjectOwner"
21+
22+
1423
if TYPE_CHECKING:
1524
from labelbox import Client
1625
from labelbox.alignerr.alignerr_project import AlignerrProject, AlignerrRole
@@ -23,6 +32,7 @@ def __init__(self, client: "Client"):
2332
self._customer_rate: ProjectRateInput = None
2433
self._domains: list[ProjectDomain] = []
2534
self._enhanced_resource_tags: list[EnhancedResourceTag] = []
35+
self._project_owner_email: Optional[str] = None
2636
self.role_name_to_id = self._get_role_name_to_id()
2737

2838
def set_name(self, name: str):
@@ -141,10 +151,24 @@ def set_tags(self, tag_texts: list[str], tag_type: ResourceTagType):
141151
self._enhanced_resource_tags.append(new_tag)
142152
return self
143153

154+
def set_project_owner(self, project_owner_email: str):
155+
"""Set the project owner for the ProjectBoostWorkforce.
156+
157+
Args:
158+
project_owner_email: Email of the user to set as project owner
159+
160+
Returns:
161+
Self for method chaining
162+
"""
163+
self._project_owner_email = project_owner_email
164+
return self
144165

145-
def create(self, skip_validation: bool = False):
166+
167+
def create(self, skip_validation: Union[bool, List[ValidationType]] = False):
146168
if not skip_validation:
147169
self._validate()
170+
elif isinstance(skip_validation, list):
171+
self._validate_selective(skip_validation)
148172
logger.info("Creating project")
149173

150174
project_data = {
@@ -163,6 +187,7 @@ def create(self, skip_validation: bool = False):
163187
self._create_rates(alignerr_project)
164188
self._create_domains(alignerr_project)
165189
self._create_resource_tags(alignerr_project)
190+
self._create_project_owner(alignerr_project)
166191

167192
return alignerr_project
168193

@@ -202,6 +227,22 @@ def _create_resource_tags(self, alignerr_project: "AlignerrProject"):
202227
tag_type_enum = ResourceTagType(tag_type_str)
203228
alignerr_project.set_tags(tag_names, tag_type_enum)
204229

230+
def _create_project_owner(self, alignerr_project: "AlignerrProject"):
231+
if self._project_owner_email:
232+
logger.info(f"Setting project owner: {self._project_owner_email}")
233+
234+
# Find user by email in the organization
235+
user_id = self._find_user_by_email(self._project_owner_email)
236+
if not user_id:
237+
current_org = self.client.get_organization()
238+
raise ValueError(f"User with email {self._project_owner_email} not found in organization {current_org.uid}")
239+
240+
ProjectBoostWorkforce.set_project_owner(
241+
client=self.client,
242+
project_id=alignerr_project.project.uid,
243+
project_owner_user_id=user_id
244+
)
245+
205246
def _validate_alignerr_rates(self):
206247
# Import here to avoid circular imports
207248
from labelbox.alignerr.alignerr_project import AlignerrRole
@@ -221,10 +262,61 @@ def _validate_customer_rate(self):
221262
if self._customer_rate is None:
222263
raise ValueError("Customer rate is not set")
223264

265+
def _validate_project_owner(self):
266+
if self._project_owner_email is None:
267+
raise ValueError("Project owner is not set")
268+
224269
def _validate(self):
225270
self._validate_alignerr_rates()
226271
self._validate_customer_rate()
272+
self._validate_project_owner()
273+
274+
def _validate_selective(self, skip_validations: List[ValidationType]):
275+
"""Run validations selectively, skipping those in the provided list.
276+
277+
Args:
278+
skip_validations: List of ValidationType enums to skip
279+
"""
280+
if ValidationType.ALIGNERR_RATE not in skip_validations:
281+
self._validate_alignerr_rates()
282+
283+
if ValidationType.CUSTOMER_RATE not in skip_validations:
284+
self._validate_customer_rate()
285+
286+
if ValidationType.PROJECT_OWNER not in skip_validations:
287+
self._validate_project_owner()
227288

228289
def _get_role_name_to_id(self) -> dict[str, str]:
229290
roles = self.client.get_roles()
230291
return {role.name: role.uid for role in roles.values()}
292+
293+
def _find_user_by_email(self, email: str) -> Optional[str]:
294+
"""Find user ID by email in the organization.
295+
296+
Args:
297+
email: Email address to search for
298+
299+
Returns:
300+
User ID if found, None otherwise
301+
"""
302+
try:
303+
# Import here to avoid circular imports
304+
from labelbox.schema.user import User
305+
306+
# Get the current organization
307+
current_org = self.client.get_organization()
308+
309+
# Use client.get_users with where clause to find user by email
310+
users = self.client.get_users(where=User.email == email)
311+
312+
# Get the first matching user and verify they belong to the same organization
313+
user = next(users, None)
314+
if user and user.organization().uid == current_org.uid:
315+
return user.uid
316+
else:
317+
logger.warning(f"User with email {email} not found in organization {current_org.uid}")
318+
return None
319+
320+
except Exception as e:
321+
logger.error(f"Error finding user by email {email}: {e}")
322+
return None

0 commit comments

Comments
 (0)