Skip to content

Commit 815b885

Browse files
committed
Changes
1 parent 6e2affd commit 815b885

File tree

9 files changed

+706
-21
lines changed

9 files changed

+706
-21
lines changed

libs/labelbox/src/labelbox/alignerr/alignerr_project.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from labelbox.alignerr.schema.project_rate import ProjectRateV2
77
from labelbox.alignerr.schema.project_domain import ProjectDomain
8+
from labelbox.alignerr.schema.enchanced_resource_tags import EnhancedResourceTag, ResourceTagType
89
from labelbox.pagination import PaginatedCollection
910

1011
logger = logging.getLogger(__name__)
@@ -43,11 +44,6 @@ def project(self, project: "Project"):
4344
self._project = project
4445

4546
def domains(self) -> PaginatedCollection:
46-
"""Get all domains associated with this project.
47-
48-
Returns:
49-
PaginatedCollection of ProjectDomain instances
50-
"""
5147
return ProjectDomain.get_by_project_id(
5248
client=self.client, project_id=self.project.uid
5349
)
@@ -59,7 +55,7 @@ def add_domain(self, project_domain: ProjectDomain):
5955
domain_ids=[project_domain.uid],
6056
)
6157

62-
def get_project_rate(self) -> Optional["ProjectRateV2"]:
58+
def get_project_rates(self) -> list["ProjectRateV2"]:
6359
return ProjectRateV2.get_by_project_id(
6460
client=self.client, project_id=self.project.uid
6561
)
@@ -71,6 +67,71 @@ def set_project_rate(self, project_rate_input):
7167
project_rate_input=project_rate_input,
7268
)
7369

70+
def set_tags(self, tag_names: list[str], tag_type: ResourceTagType):
71+
# Convert tag names to tag IDs
72+
tag_ids = []
73+
for tag_name in tag_names:
74+
# Search for the tag by text to get its ID
75+
found_tags = EnhancedResourceTag.search_by_text(self.client, search_text=tag_name, tag_type=tag_type)
76+
if found_tags:
77+
tag_ids.append(found_tags[0].id)
78+
79+
# Use the existing project resource tag functionality with IDs
80+
self.project.update_project_resource_tags(tag_ids)
81+
return self
82+
83+
def get_tags(self) -> list[EnhancedResourceTag]:
84+
"""Get enhanced resource tags associated with this project.
85+
86+
Returns:
87+
List of EnhancedResourceTag instances
88+
"""
89+
# Get project resource tags and convert to EnhancedResourceTag instances
90+
project_resource_tags = self.project.get_resource_tags()
91+
enhanced_tags = []
92+
for tag in project_resource_tags:
93+
# Search for the corresponding EnhancedResourceTag by text (try different types)
94+
found_tags = []
95+
for tag_type in [ResourceTagType.Default, ResourceTagType.Billing]:
96+
found_tags = EnhancedResourceTag.search_by_text(self.client, search_text=tag.text, tag_type=tag_type)
97+
if found_tags:
98+
break
99+
if found_tags:
100+
enhanced_tags.extend(found_tags)
101+
return enhanced_tags
102+
103+
def add_tag(self, tag: EnhancedResourceTag):
104+
"""Add a single enhanced resource tag to the project.
105+
106+
Args:
107+
tag: EnhancedResourceTag instance to add
108+
109+
Returns:
110+
Self for method chaining
111+
"""
112+
current_tags = self.get_tags()
113+
current_tag_names = [t.text for t in current_tags]
114+
115+
if tag.text not in current_tag_names:
116+
current_tag_names.append(tag.text)
117+
self.set_tags(current_tag_names)
118+
119+
return self
120+
121+
def remove_tag(self, tag: EnhancedResourceTag):
122+
"""Remove a single enhanced resource tag from the project.
123+
124+
Args:
125+
tag: EnhancedResourceTag instance to remove
126+
127+
Returns:
128+
Self for method chaining
129+
"""
130+
current_tags = self.get_tags()
131+
current_tag_names = [t.text for t in current_tags if t.uid != tag.uid]
132+
self.set_tags(current_tag_names)
133+
return self
134+
74135

75136
class AlignerrWorkspace:
76137
def __init__(self, client: "Client"):

libs/labelbox/src/labelbox/alignerr/alignerr_project_builder.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from labelbox.alignerr.schema.project_rate import BillingMode
66
from labelbox.alignerr.schema.project_rate import ProjectRateInput
7-
from labelbox.alignerr.schema.project_rate import ProjectRateV2
87
from labelbox.alignerr.schema.project_domain import ProjectDomain
8+
from labelbox.alignerr.schema.enchanced_resource_tags import EnhancedResourceTag, ResourceTagType
99
from labelbox.schema.media_type import MediaType
1010

1111
logger = logging.getLogger(__name__)
@@ -22,6 +22,7 @@ def __init__(self, client: "Client"):
2222
self._alignerr_rates: dict[str, ProjectRateInput] = {}
2323
self._customer_rate: ProjectRateInput = None
2424
self._domains: list[ProjectDomain] = []
25+
self._enhanced_resource_tags: list[EnhancedResourceTag] = []
2526
self.role_name_to_id = self._get_role_name_to_id()
2627

2728
def set_name(self, name: str):
@@ -110,6 +111,37 @@ def set_domains(self, domains: list[str]):
110111
self._domains.append(domain_result)
111112
return self
112113

114+
def set_tags(self, tag_texts: list[str], tag_type: ResourceTagType):
115+
"""Set enhanced resource tags for the project.
116+
117+
Args:
118+
tag_texts: List of tag text values to search for and attach
119+
tag_type: Type filter for searching tags
120+
121+
Returns:
122+
Self for method chaining
123+
"""
124+
for tag_text in tag_texts:
125+
# Search for existing tags by text
126+
existing_tags = EnhancedResourceTag.search_by_text(
127+
self.client, search_text=tag_text, tag_type=tag_type
128+
)
129+
130+
if existing_tags:
131+
# Use the first matching tag
132+
self._enhanced_resource_tags.append(existing_tags[0])
133+
else:
134+
# Create new tag if not found
135+
new_tag = EnhancedResourceTag.create(
136+
self.client,
137+
text=tag_text,
138+
color="#007bff", # Default blue color
139+
tag_type=tag_type
140+
)
141+
self._enhanced_resource_tags.append(new_tag)
142+
return self
143+
144+
113145
def create(self, skip_validation: bool = False):
114146
if not skip_validation:
115147
self._validate()
@@ -130,6 +162,7 @@ def create(self, skip_validation: bool = False):
130162

131163
self._create_rates(alignerr_project)
132164
self._create_domains(alignerr_project)
165+
self._create_resource_tags(alignerr_project)
133166

134167
return alignerr_project
135168

@@ -150,6 +183,25 @@ def _create_domains(self, alignerr_project: "AlignerrProject"):
150183
domain_ids=domain_ids,
151184
)
152185

186+
def _create_resource_tags(self, alignerr_project: "AlignerrProject"):
187+
if self._enhanced_resource_tags:
188+
logger.info(
189+
f"Setting enhanced resource tags: {[tag.text for tag in self._enhanced_resource_tags]}"
190+
)
191+
# Group tags by type and set them accordingly
192+
tags_by_type = {}
193+
for tag in self._enhanced_resource_tags:
194+
tag_type = tag.type
195+
if tag_type not in tags_by_type:
196+
tags_by_type[tag_type] = []
197+
tags_by_type[tag_type].append(tag.text)
198+
199+
# Set tags for each type
200+
for tag_type_str, tag_names in tags_by_type.items():
201+
# Convert string back to enum
202+
tag_type_enum = ResourceTagType(tag_type_str)
203+
alignerr_project.set_tags(tag_names, tag_type_enum)
204+
153205
def _validate_alignerr_rates(self):
154206
# Import here to avoid circular imports
155207
from labelbox.alignerr.alignerr_project import AlignerrRole
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from enum import Enum
2+
from typing import List, Optional
3+
from labelbox.orm.db_object import DbObject, Updateable
4+
from labelbox.orm.model import Field
5+
from pydantic import BaseModel
6+
7+
8+
class ResourceTagType(Enum):
9+
"""Enum for resource tag types."""
10+
Default = "Default"
11+
System = "System"
12+
Request = "Request"
13+
Migration = "Migration"
14+
Billing = "Billing"
15+
16+
17+
class CreateResourceTagInput(BaseModel):
18+
"""Input for creating a new resource tag."""
19+
20+
text: str
21+
color: str
22+
type: Optional[str] = None
23+
24+
25+
class UpdateResourceTagInput(BaseModel):
26+
"""Input for updating a resource tag."""
27+
28+
id: str
29+
text: str
30+
color: str
31+
type: Optional[str] = None
32+
33+
34+
class DeleteResourceTagInput(BaseModel):
35+
"""Input for deleting a resource tag."""
36+
37+
id: str
38+
type: Optional[str] = None
39+
40+
41+
class ResourceTagsInput(BaseModel):
42+
"""Input for querying resource tags."""
43+
44+
type: str
45+
46+
47+
class EnhancedResourceTag(DbObject, Updateable):
48+
"""Enhanced resource tag with additional functionality and type support."""
49+
50+
# Fields matching the DDL schema
51+
id = Field.String("id")
52+
createdAt = Field.DateTime("createdAt")
53+
updatedAt = Field.DateTime("updatedAt")
54+
organizationId = Field.String("organizationId")
55+
text = Field.String("text")
56+
color = Field.String("color")
57+
createdById = Field.String("createdById")
58+
type = Field.String("type")
59+
60+
@classmethod
61+
def create(
62+
cls, client, text: str, color: str, tag_type: Optional[ResourceTagType] = None
63+
) -> "EnhancedResourceTag":
64+
"""Create a new enhanced resource tag.
65+
66+
Args:
67+
client: Labelbox client instance
68+
text: Text content of the resource tag
69+
color: Color of the resource tag
70+
tag_type: Optional type of the resource tag
71+
72+
Returns:
73+
Created EnhancedResourceTag instance
74+
"""
75+
# Use the existing organization create_resource_tag method
76+
# Get the organization
77+
org = client.get_organization()
78+
79+
# Create the tag using existing API
80+
tag_data = {"text": text, "color": color}
81+
created_tag = org.create_resource_tag(tag_data)
82+
83+
# Create EnhancedResourceTag with the same data plus defaults for missing fields
84+
enhanced_tag = cls(client, {
85+
"id": created_tag.uid,
86+
"text": created_tag.text,
87+
"color": created_tag.color,
88+
"createdAt": None,
89+
"updatedAt": None,
90+
"organizationId": None,
91+
"createdById": None,
92+
"type": tag_type.value if tag_type else None
93+
})
94+
95+
return enhanced_tag
96+
97+
98+
99+
@classmethod
100+
def search_by_text(
101+
cls, client, search_text: str, tag_type: ResourceTagType
102+
) -> List["EnhancedResourceTag"]:
103+
"""Search resource tags by text content.
104+
105+
Args:
106+
client: Labelbox client instance
107+
search_text: Text to search for
108+
tag_type: Type filter
109+
110+
Returns:
111+
List of matching EnhancedResourceTag instances
112+
"""
113+
# Use the existing organization get_resource_tags method
114+
# Get the organization
115+
org = client.get_organization()
116+
117+
# Get all resource tags
118+
regular_tags = org.get_resource_tags()
119+
120+
# Convert to EnhancedResourceTag instances and filter by search text and type
121+
matching_tags = []
122+
for tag in regular_tags:
123+
if search_text.lower() in tag.text.lower():
124+
enhanced_tag = cls(client, {
125+
"id": tag.uid,
126+
"text": tag.text,
127+
"color": tag.color,
128+
"createdAt": None,
129+
"updatedAt": None,
130+
"organizationId": None,
131+
"createdById": None,
132+
"type": tag_type.value
133+
})
134+
135+
# Apply type filter
136+
if enhanced_tag.type == tag_type.value:
137+
matching_tags.append(enhanced_tag)
138+
139+
return matching_tags

libs/labelbox/src/labelbox/alignerr/schema/project_rate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ProjectRateV2(DbObject, Deletable):
5353
effectiveUntil = Field.DateTime("effectiveUntil")
5454

5555
@classmethod
56-
def get_by_project_id(cls, client, project_id: str) -> "ProjectRateV2":
56+
def get_by_project_id(cls, client, project_id: str) -> list["ProjectRateV2"]:
5757
query_str = """
5858
query GetAllProjectRatesPyApi($projectId: ID!) {
5959
project(where: { id: $projectId }) {
@@ -84,10 +84,10 @@ def get_by_project_id(cls, client, project_id: str) -> "ProjectRateV2":
8484
rates_data = result["project"]["ratesV2"]
8585

8686
if not rates_data:
87-
return None
87+
return []
8888

89-
# Return the first rate as a ProjectRateV2 object
90-
return cls(client, rates_data[0])
89+
# Return all rates as ProjectRateV2 objects
90+
return [cls(client, rate_data) for rate_data in rates_data]
9191

9292
@classmethod
9393
def set_project_rate(

libs/labelbox/tests/integration/test_alignerr_project.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,16 @@ def test_alignerr_project_domains(client, test_alignerr_project):
7373
# The collection might be empty for a new project, which is expected
7474

7575

76-
def test_alignerr_project_get_project_rate_no_rates(
76+
def test_alignerr_project_get_project_rates_no_rates(
7777
client, test_alignerr_project
7878
):
79-
"""Test get_project_rate() when no rates are set."""
80-
# For a new project without rates, this should return None
81-
project_rate = test_alignerr_project.get_project_rate()
82-
assert project_rate is None
79+
"""Test get_project_rates() when no rates are set."""
80+
# For a new project without rates, this should return an empty list
81+
project_rates = test_alignerr_project.get_project_rates()
82+
assert project_rates == []
8383

8484

85-
def test_alignerr_project_set_and_get_project_rate(
85+
def test_alignerr_project_set_and_get_project_rates(
8686
client, test_alignerr_project
8787
):
8888
"""Test setting and getting project rates."""
@@ -100,7 +100,10 @@ def test_alignerr_project_set_and_get_project_rate(
100100
result = test_alignerr_project.set_project_rate(project_rate_input)
101101
assert result is True # Should return success status
102102

103-
# Get the project rate back
104-
project_rate = test_alignerr_project.get_project_rate()
103+
# Get the project rates back
104+
project_rates = test_alignerr_project.get_project_rates()
105+
# Should return a list with at least one rate
106+
assert isinstance(project_rates, list)
107+
assert len(project_rates) >= 1
105108
# Note: The actual rate retrieval might depend on the API implementation
106109
# This test verifies the method calls work without errors

0 commit comments

Comments
 (0)