Skip to content

Commit 24aae0a

Browse files
authored
Add arrow binary data type support (#1702)
* Add arrow binary data type support Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for string type Signed-off-by: Yong Tang <yong.tang.github@outlook.com> Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent 03da6dc commit 24aae0a

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

tensorflow_io/core/kernels/arrow/arrow_util.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ Status GetTensorFlowType(std::shared_ptr<::arrow::DataType> dtype,
3333
*out = ::tensorflow::DT_STRING;
3434
return Status::OK();
3535
}
36+
if (dtype->id() == ::arrow::Type::BINARY) {
37+
*out = ::tensorflow::DT_STRING;
38+
return Status::OK();
39+
}
3640
::arrow::Status status =
3741
::arrow::adapters::tensorflow::GetTensorFlowType(dtype, out);
3842
if (!status.ok()) {
@@ -118,6 +122,7 @@ class ArrowAssignSpecImpl : public arrow::ArrayVisitor {
118122
VISIT_PRIMITIVE(arrow::FloatArray)
119123
VISIT_PRIMITIVE(arrow::DoubleArray)
120124
VISIT_PRIMITIVE(arrow::StringArray)
125+
VISIT_PRIMITIVE(arrow::BinaryArray)
121126
#undef VISIT_PRIMITIVE
122127

123128
virtual arrow::Status Visit(const arrow::ListArray& array) override {
@@ -286,6 +291,17 @@ class ArrowAssignTensorImpl : public arrow::ArrayVisitor {
286291
return arrow::Status::OK();
287292
}
288293

294+
virtual arrow::Status Visit(const arrow::BinaryArray& array) override {
295+
auto shape = out_tensor_->shape();
296+
auto output_flat = out_tensor_->flat<tstring>();
297+
298+
for (int64 j = 0; j < shape.num_elements(); ++j) {
299+
output_flat(j) = array.GetString(i_ + j);
300+
}
301+
302+
return arrow::Status::OK();
303+
}
304+
289305
private:
290306
int64 i_;
291307
int32 curr_array_length_;

tests/test_arrow.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,74 @@ def test_arrow_list_feather_columns(self):
12301230

12311231
os.unlink(f.name)
12321232

1233+
def test_arrow_feather_dataset_binary(self):
1234+
"""test_arrow_feather_dataset_binary"""
1235+
import tensorflow_io.arrow as arrow_io
1236+
1237+
from pyarrow import feather as pa_feather
1238+
import pyarrow as pa
1239+
import numpy as np
1240+
1241+
local_path = os.path.join(
1242+
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png"
1243+
)
1244+
with open(local_path, "rb") as f:
1245+
data = [f.read()]
1246+
table = pa.Table.from_arrays([data], ["data"])
1247+
1248+
chunk_size = 1000
1249+
with tempfile.NamedTemporaryFile(delete=False) as f:
1250+
pa_feather.write_feather(table, f, chunksize=chunk_size)
1251+
1252+
dataset = arrow_io.ArrowFeatherDataset(
1253+
f.name,
1254+
columns=(0,),
1255+
output_types=(tf.string),
1256+
output_shapes=([]),
1257+
batch_size=32,
1258+
)
1259+
print(dataset.element_spec)
1260+
sample = next(iter(dataset))
1261+
assert sample.shape == [1]
1262+
assert sample.dtype == tf.string
1263+
assert np.all(sample == data[0])
1264+
print(sample.dtype, sample.shape)
1265+
1266+
os.unlink(f.name)
1267+
1268+
def test_arrow_feather_dataset_string(self):
1269+
"""test_arrow_feather_dataset_string"""
1270+
import tensorflow_io.arrow as arrow_io
1271+
1272+
from pyarrow import feather as pa_feather
1273+
import pyarrow as pa
1274+
import numpy as np
1275+
import random
1276+
import string
1277+
1278+
data = ["".join(random.choice(string.printable) for _ in range(10000))]
1279+
table = pa.Table.from_arrays([data], schema=pa.schema([("data", pa.string())]))
1280+
1281+
chunk_size = 1000
1282+
with tempfile.NamedTemporaryFile(delete=False) as f:
1283+
pa_feather.write_feather(table, f, chunksize=chunk_size)
1284+
1285+
dataset = arrow_io.ArrowFeatherDataset(
1286+
f.name,
1287+
columns=(0,),
1288+
output_types=(tf.string),
1289+
output_shapes=([]),
1290+
batch_size=32,
1291+
)
1292+
print(dataset.element_spec)
1293+
sample = next(iter(dataset))
1294+
assert sample.shape == [1]
1295+
assert sample.dtype == tf.string
1296+
assert np.all(sample == data[0])
1297+
print(sample.dtype, sample.shape)
1298+
1299+
os.unlink(f.name)
1300+
12331301

12341302
if __name__ == "__main__":
12351303
test.main()

0 commit comments

Comments
 (0)