Skip to content

Commit 62f116b

Browse files
committed
Take colormap argument
1 parent 6d10d8b commit 62f116b

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"The Labelbox python package."
22

3-
__version__ = '0.0.5'
3+
__version__ = '0.0.5.dev1'

labelbox/lbx.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from io import BytesIO
2828
import itertools
2929
import struct
30+
from typing import List
3031

3132
import numpy as np
3233
from PIL import Image
@@ -36,11 +37,14 @@
3637
_HEADER_LENGTH = 6 * 4
3738

3839

39-
def encode(image_in: Image):
40+
def encode(image_in: Image, colormap: List[np.array]):
4041
"""Converts a RGB `Image` to a `io.BytesIO` with LBX encoded data.
4142
4243
Args:
4344
image_in: The image to encode.
45+
colormap: Ordered list of `np.array`s each of length 3 representing
46+
a RGB color. The ordering of this list determines which colors
47+
map to which class labels in the project ontology.
4448
4549
Returns:
4650
A `io.BytesIO` containing the LBX encoded image.
@@ -49,8 +53,8 @@ def encode(image_in: Image):
4953
pixel_words = np.array(image).reshape(-1, 4)
5054
pixel_words.flags.writeable = False
5155

52-
colormap = list(
53-
filter(lambda x: not np.all(x == _BACKGROUND_RGBA), np.unique(pixel_words, axis=0)))
56+
colormap = [np.zeros(4, dtype=np.uint8)] + \
57+
list(map(lambda color: np.append(color, 255).astype(np.uint8), colormap))
5458

5559
input_byte_len = len(np.array(image).flat)
5660
buff = BytesIO(bytes([0] * (len(colormap) * 4 + input_byte_len)))

tests/test_lbx.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,20 @@ def test_lbx_decode(lbx_sample):
2525

2626

2727
def test_lbx_encode(im_png):
28-
lbx_encoded = lbx.encode(im_png)
28+
colormap = [
29+
np.array([0, 0, 128]),
30+
np.array([0, 128, 0]),
31+
]
32+
lbx_encoded = lbx.encode(im_png, colormap)
2933
version, width, height = map(lambda x: x[0], struct.iter_unpack('<i', lbx_encoded.read(12)))
3034
assert version == 1
3135
assert width == 500
32-
assert height == 600
36+
assert height == 375
3337

3438

3539
def test_identity(im_png):
36-
assert np.all(np.array(lbx.decode(lbx.encode(im_png))) == np.array(im_png))
40+
colormap = [
41+
np.array([0, 0, 128]),
42+
np.array([0, 128, 0]),
43+
]
44+
assert np.all(np.array(lbx.decode(lbx.encode(im_png, colormap))) == np.array(im_png))

0 commit comments

Comments
 (0)