|
| 1 | +import json |
| 2 | +import random |
| 3 | + |
| 4 | +from labelbox import StreamType, JsonConverter |
| 5 | + |
| 6 | + |
class TestExportEmbeddings:
    """Integration tests for exporting a dataset with embeddings enabled."""

    @staticmethod
    def _export_results(dataset):
        """Export ``dataset`` with embeddings and return the parsed rows.

        Starts an export with ``params={"embeddings": True}``, waits for it,
        asserts that it completed with a result and without errors, then
        drains the RESULT stream.

        Args:
            dataset: The dataset fixture to export.

        Returns:
            A list with one ``json.loads``-decoded object per streamed output.
        """
        export_task = dataset.export(params={"embeddings": True})
        export_task.wait_till_done()
        assert export_task.status == "COMPLETE"
        assert export_task.has_result()
        assert export_task.has_errors() is False

        results = []
        stream = export_task.get_stream(converter=JsonConverter(),
                                        stream_type=StreamType.RESULT)
        stream.start(stream_handler=lambda output: results.append(
            json.loads(output.json_str)))
        return results

    def test_export_embeddings_precomputed(self, client, dataset, environ,
                                           image_url):
        """Check the export of an image data row created without embeddings.

        The exported row must carry a non-empty ``embeddings`` list whose
        first entry is named "Image Embedding V2 (CLIP ViT-B/32)" and holds
        exactly one ``values`` item.

        ``client`` and ``environ`` are not referenced in the body; they are
        kept so the fixture signature stays unchanged (presumably requested
        for their setup side effects -- TODO confirm).
        """
        data_row_specs = [{
            "row_data": image_url,
            "external_id": "image",
        }]
        task = dataset.create_data_rows(data_row_specs)
        task.wait_till_done()
        # Fail here, not later in the export, if the import itself failed;
        # this mirrors the check done in test_export_embeddings_custom.
        assert task.status == "COMPLETE"

        results = self._export_results(dataset)
        assert len(results) == len(data_row_specs)

        result = results[0]
        assert "embeddings" in result
        assert len(result["embeddings"]) > 0
        first = result["embeddings"][0]
        assert first["name"] == "Image Embedding V2 (CLIP ViT-B/32)"
        assert len(first["values"]) == 1

    def test_export_embeddings_custom(self, client, dataset, image_url,
                                      embedding):
        """Check that an uploaded custom embedding vector is exported intact.

        A data row is created with a random vector attached to the
        ``embedding`` fixture; the export must contain an entry for that
        embedding whose name, dimensions, custom flag and vector match.

        ``client`` is not referenced in the body; it is kept so the fixture
        signature stays unchanged.
        """
        vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)]
        import_task = dataset.create_data_rows([{
            "row_data": image_url,
            "embeddings": [{
                "embedding_id": embedding.id,
                "vector": vector,
            }],
        }])
        import_task.wait_till_done()
        assert import_task.status == "COMPLETE"

        results = self._export_results(dataset)
        assert len(results) == 1
        assert "embeddings" in results[0]

        exported = results[0]["embeddings"]
        # Other embeddings may be present too, so select by id. Requiring a
        # match keeps the test from passing vacuously when the custom
        # embedding is missing from the export.
        matches = [emb for emb in exported if emb["id"] == embedding.id]
        assert matches, (
            f"custom embedding {embedding.id} not found in exported "
            f"embeddings: {[emb['id'] for emb in exported]}")

        for emb in matches:
            assert emb["name"] == embedding.name
            assert emb["dimensions"] == embedding.dims
            assert emb["is_custom"] is True
            assert len(emb["values"]) == 1
            assert emb["values"][0]["value"] == vector
0 commit comments