Skip to content

Commit 214595e

Browse files
authored
Merge branch 'main' into rodrigo/fix-cli-format-export
2 parents 335df57 + 6a6025a commit 214595e

File tree

3 files changed

+155
-2
lines changed

3 files changed

+155
-2
lines changed

roboflow/core/project.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ def search(
653653
batch: bool = False,
654654
batch_id: Optional[str] = None,
655655
fields: Optional[List[str]] = None,
656+
*,
657+
annotation_job: Optional[bool] = None,
658+
annotation_job_id: Optional[str] = None,
656659
):
657660
"""
658661
Search for images in a project.
@@ -667,6 +670,8 @@ def search(
667670
in_dataset (str): dataset that an image must be in
668671
batch (bool): whether the image must be in a batch
669672
batch_id (str): batch id that an image must be in
673+
annotation_job (bool): whether the image must be in an annotation job
674+
annotation_job_id (str): annotation job id that an image must be in
670675
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])
671676
672677
Returns:
@@ -684,7 +689,7 @@ def search(
684689
if fields is None:
685690
fields = ["id", "created", "name", "labels"]
686691

687-
payload: Dict[str, Union[str, int, List[str]]] = {}
692+
payload: Dict[str, Union[str, int, bool, List[str]]] = {}
688693

689694
if like_image is not None:
690695
payload["like_image"] = like_image
@@ -713,6 +718,12 @@ def search(
713718
if batch_id is not None:
714719
payload["batch_id"] = batch_id
715720

721+
if annotation_job is not None:
722+
payload["annotation_job"] = annotation_job
723+
724+
if annotation_job_id is not None:
725+
payload["annotation_job_id"] = annotation_job_id
726+
716727
payload["fields"] = fields
717728

718729
data = requests.post(
@@ -734,6 +745,9 @@ def search_all(
734745
batch: bool = False,
735746
batch_id: Optional[str] = None,
736747
fields: Optional[List[str]] = None,
748+
*,
749+
annotation_job: Optional[bool] = None,
750+
annotation_job_id: Optional[str] = None,
737751
):
738752
"""
739753
Create a paginated list of search results for use in searching the images in a project.
@@ -748,6 +762,8 @@ def search_all(
748762
in_dataset (str): dataset that an image must be in
749763
batch (bool): whether the image must be in a batch
750764
batch_id (str): batch id that an image must be in
765+
annotation_job (bool): whether the image must be in an annotation job
766+
annotation_job_id (str): annotation job id that an image must be in
751767
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])
752768
753769
Returns:
@@ -781,6 +797,8 @@ def search_all(
781797
batch=batch,
782798
batch_id=batch_id,
783799
fields=fields,
800+
annotation_job=annotation_job,
801+
annotation_job_id=annotation_job_id,
784802
)
785803

786804
yield data

tests/manual/debugme.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
os.environ["ROBOFLOW_CONFIG_DIR"] = f"{thisdir}/data/.config"
66

77
from roboflow.roboflowpy import _argparser # noqa: E402
8+
from roboflow import Roboflow
89

910
# import requests
1011
# requests.urllib3.disable_warnings()
1112

1213
rootdir = os.path.abspath(f"{thisdir}/../..")
1314
sys.path.append(rootdir)
1415

15-
if __name__ == "__main__":
16+
17+
def run_cli():
1618
parser = _argparser()
1719
# args = parser.parse_args(["login"])
1820
# args = parser.parse_args(f"upload {thisdir}/../datasets/chess -w wolfodorpythontests -p chess".split()) # noqa: E501 // docs
@@ -45,3 +47,32 @@
4547
# f"import -w tonyprivate -p meh-plvrv {thisdir}/../datasets/paligemma/".split() # noqa: E501 // docs
4648
)
4749
args.func(args)
50+
51+
52+
def run_api_train():
53+
rf = Roboflow()
54+
project = rf.workspace("meh3").project("mosquitobao")
55+
# version_number = project.generate_version(
56+
# settings={
57+
# "augmentation": {
58+
# "bbblur": {"pixels": 1.5},
59+
# "image": {"versions": 2},
60+
# },
61+
# "preprocessing": {
62+
# "auto-orient": True,
63+
# },
64+
# }
65+
# )
66+
version_number = "61"
67+
print(version_number)
68+
version = project.version(version_number)
69+
model = version.train(
70+
speed="fast", # Options: "fast" (default) or "accurate" (paid feature)
71+
checkpoint=None, # Use a specific checkpoint to continue training
72+
)
73+
print(model)
74+
75+
76+
if __name__ == "__main__":
77+
# run_cli()
78+
run_api_train()

tests/test_project.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,107 @@ def capture_annotation_calls(annotation_path, **kwargs):
667667
finally:
668668
for mock in mocks.values():
669669
mock.stop()
670+
671+
def test_search_with_annotation_job_params(self):
672+
"""Test that annotation_job and annotation_job_id parameters are properly included in search requests"""
673+
# Test 1: Search with annotation_job=True
674+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/search?api_key={ROBOFLOW_API_KEY}"
675+
mock_response = {
676+
"results": [
677+
{"id": "image1", "name": "test1.jpg", "created": 1616161616, "labels": ["person"]},
678+
{"id": "image2", "name": "test2.jpg", "created": 1616161617, "labels": ["car"]},
679+
]
680+
}
681+
682+
responses.add(
683+
responses.POST,
684+
expected_url,
685+
json=mock_response,
686+
status=200,
687+
match=[
688+
json_params_matcher(
689+
{
690+
"offset": 0,
691+
"limit": 100,
692+
"batch": False,
693+
"annotation_job": True,
694+
"fields": ["id", "created", "name", "labels"],
695+
}
696+
)
697+
],
698+
)
699+
700+
results = self.project.search(annotation_job=True)
701+
self.assertEqual(len(results), 2)
702+
self.assertEqual(results[0]["id"], "image1")
703+
704+
# Test 2: Search with annotation_job_id
705+
test_job_id = "job_123456"
706+
responses.add(
707+
responses.POST,
708+
expected_url,
709+
json=mock_response,
710+
status=200,
711+
match=[
712+
json_params_matcher(
713+
{
714+
"offset": 0,
715+
"limit": 100,
716+
"batch": False,
717+
"annotation_job_id": test_job_id,
718+
"fields": ["id", "created", "name", "labels"],
719+
}
720+
)
721+
],
722+
)
723+
724+
results = self.project.search(annotation_job_id=test_job_id)
725+
self.assertEqual(len(results), 2)
726+
727+
# Test 3: Search with both parameters
728+
responses.add(
729+
responses.POST,
730+
expected_url,
731+
json=mock_response,
732+
status=200,
733+
match=[
734+
json_params_matcher(
735+
{
736+
"offset": 0,
737+
"limit": 50,
738+
"batch": False,
739+
"annotation_job": False,
740+
"annotation_job_id": test_job_id,
741+
"prompt": "dog",
742+
"fields": ["id", "created", "name", "labels"],
743+
}
744+
)
745+
],
746+
)
747+
748+
results = self.project.search(prompt="dog", annotation_job=False, annotation_job_id=test_job_id, limit=50)
749+
self.assertEqual(len(results), 2)
750+
751+
# Test 4: Verify parameters are not included when None
752+
responses.add(
753+
responses.POST,
754+
expected_url,
755+
json=mock_response,
756+
status=200,
757+
match=[
758+
json_params_matcher(
759+
{
760+
"offset": 0,
761+
"limit": 100,
762+
"batch": False,
763+
"fields": ["id", "created", "name", "labels"],
764+
# annotation_job and annotation_job_id should NOT be in the payload
765+
}
766+
)
767+
],
768+
)
769+
770+
# This should pass because json_params_matcher only checks that the
771+
# specified keys match, it doesn't fail if additional keys are missing
772+
results = self.project.search()
773+
self.assertEqual(len(results), 2)

0 commit comments

Comments
 (0)