Skip to content

Commit 3416746

Browse files
committed
add project.batches() test
1 parent 2ca4d86 commit 3416746

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

labelbox/pagination.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self,
2525
params: Dict[str, str],
2626
dereferencing: Union[List[str], Dict[str, Any]],
2727
obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]],
28-
cursor_path: Optional[Dict[str, Any]] = None,
28+
cursor_path: Optional[List[str]] = None,
2929
experimental: bool = False):
3030
""" Creates a PaginatedCollection.
3131
@@ -105,7 +105,7 @@ def get_next_page(self) -> Tuple[Dict[str, Any], bool]:
105105

106106
class _CursorPagination(_Pagination):
107107

108-
def __init__(self, cursor_path: Dict[str, Any], *args, **kwargs):
108+
def __init__(self, cursor_path: List[str], *args, **kwargs):
109109
super().__init__(*args, **kwargs)
110110
self.cursor_path = cursor_path
111111
self.next_cursor: Optional[Any] = None

labelbox/schema/project.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -855,12 +855,7 @@ def batches(self) -> PaginatedCollection:
855855
self.client,
856856
query_str, {id_param: self.uid}, ['project', 'batches', 'nodes'],
857857
lambda client, res: Entity.Batch(client, self.uid, res),
858-
cursor_path={
859-
'project': None,
860-
'batches': None,
861-
'pageInfo': None,
862-
'endCursor': None
863-
},
858+
cursor_path=['project', 'batches', 'pageInfo', 'endCursor'],
864859
experimental=True)
865860

866861
def upload_annotations(

tests/integration/test_project.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import requests
66

7-
from labelbox import Project, LabelingFrontend
7+
from labelbox import Project, LabelingFrontend, Dataset
88
from labelbox.exceptions import InvalidQueryError
99

1010

@@ -201,3 +201,22 @@ def test_queue_mode(configured_project: Project):
201201
) == configured_project.QueueMode.Dataset
202202
configured_project.update(queue_mode=configured_project.QueueMode.Batch)
203203
assert configured_project.queue_mode() == configured_project.QueueMode.Batch
204+
205+
206+
def test_batches(configured_project: Project, dataset: Dataset, image_url):
207+
task = dataset.create_data_rows([
208+
{
209+
"row_data": image_url,
210+
"external_id": "my-image"
211+
},
212+
] * 2)
213+
task.wait_till_done()
214+
configured_project.update(queue_mode=configured_project.QueueMode.Batch)
215+
data_rows = [dr.uid for dr in list(dataset.export_data_rows())]
216+
batch_one = 'batch one'
217+
batch_two = 'batch two'
218+
configured_project.create_batch(batch_one, [data_rows[0]])
219+
configured_project.create_batch(batch_two, [data_rows[1]])
220+
221+
names = set([batch.name for batch in list(configured_project.batches())])
222+
assert names == set([batch_one, batch_two])

0 commit comments

Comments
 (0)