Skip to content

Commit 03da6dc

Browse files
authored
Add string support to feather file (#1698)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent f31422e commit 03da6dc

File tree

2 files changed

+83
-61
lines changed

2 files changed

+83
-61
lines changed

tensorflow_io/core/kernels/arrow/arrow_kernels.cc

Lines changed: 41 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -544,97 +544,68 @@ class FeatherReadable : public IOReadableInterface {
544544
new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
545545
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
546546

547-
// FEA1.....[metadata][uint32 metadata_length]FEA1
548-
static constexpr const char* kFeatherMagicBytes = "FEA1";
549-
550-
size_t header_length = strlen(kFeatherMagicBytes);
551-
size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
552-
553-
string buffer;
554-
buffer.resize(header_length > footer_length ? header_length
555-
: footer_length);
556-
557-
StringPiece result;
558-
559-
TF_RETURN_IF_ERROR(file_->Read(0, header_length, &result, &buffer[0]));
560-
if (memcmp(buffer.data(), kFeatherMagicBytes, header_length) != 0) {
561-
return errors::InvalidArgument("not a feather file");
547+
std::shared_ptr<ArrowRandomAccessFile> feather_file;
548+
feather_file.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
549+
auto maybe_reader = arrow::ipc::feather::Reader::Open(feather_file);
550+
if (!maybe_reader.ok()) {
551+
return errors::Internal(maybe_reader.status().ToString());
562552
}
553+
std::shared_ptr<arrow::ipc::feather::Reader> reader =
554+
maybe_reader.ValueOrDie();
555+
std::shared_ptr<arrow::Schema> schema = reader->schema();
563556

564-
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length, footer_length,
565-
&result, &buffer[0]));
566-
if (memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes,
567-
footer_length - sizeof(uint32)) != 0) {
568-
return errors::InvalidArgument("incomplete feather file");
569-
}
570-
571-
uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
572-
573-
buffer.resize(metadata_length);
574-
575-
TF_RETURN_IF_ERROR(file_->Read(file_size_ - footer_length - metadata_length,
576-
metadata_length, &result, &buffer[0]));
577-
578-
const ::arrow::ipc::feather::fbs::CTable* table =
579-
::arrow::ipc::feather::fbs::GetCTable(buffer.data());
580-
581-
if (table->version() < ::arrow::ipc::feather::kFeatherV1Version) {
582-
return errors::InvalidArgument("feather file is old: ", table->version(),
583-
" vs. ",
584-
::arrow::ipc::feather::kFeatherV1Version);
557+
std::shared_ptr<arrow::Table> table;
558+
arrow::Status s = reader->Read(&table);
559+
if (!s.ok()) {
560+
return errors::Internal(s.ToString());
585561
}
586562

587-
for (size_t i = 0; i < table->columns()->size(); i++) {
563+
for (int i = 0; i < schema->num_fields(); i++) {
588564
::tensorflow::DataType dtype = ::tensorflow::DataType::DT_INVALID;
589-
switch (table->columns()->Get(i)->values()->type()) {
590-
case ::arrow::ipc::feather::fbs::Type::BOOL:
565+
switch (schema->field(i)->type()->id()) {
566+
case ::arrow::Type::BOOL:
591567
dtype = ::tensorflow::DataType::DT_BOOL;
592568
break;
593-
case ::arrow::ipc::feather::fbs::Type::INT8:
569+
case ::arrow::Type::INT8:
594570
dtype = ::tensorflow::DataType::DT_INT8;
595571
break;
596-
case ::arrow::ipc::feather::fbs::Type::INT16:
572+
case ::arrow::Type::INT16:
597573
dtype = ::tensorflow::DataType::DT_INT16;
598574
break;
599-
case ::arrow::ipc::feather::fbs::Type::INT32:
575+
case ::arrow::Type::INT32:
600576
dtype = ::tensorflow::DataType::DT_INT32;
601577
break;
602-
case ::arrow::ipc::feather::fbs::Type::INT64:
578+
case ::arrow::Type::INT64:
603579
dtype = ::tensorflow::DataType::DT_INT64;
604580
break;
605-
case ::arrow::ipc::feather::fbs::Type::UINT8:
581+
case ::arrow::Type::UINT8:
606582
dtype = ::tensorflow::DataType::DT_UINT8;
607583
break;
608-
case ::arrow::ipc::feather::fbs::Type::UINT16:
584+
case ::arrow::Type::UINT16:
609585
dtype = ::tensorflow::DataType::DT_UINT16;
610586
break;
611-
case ::arrow::ipc::feather::fbs::Type::UINT32:
587+
case ::arrow::Type::UINT32:
612588
dtype = ::tensorflow::DataType::DT_UINT32;
613589
break;
614-
case ::arrow::ipc::feather::fbs::Type::UINT64:
590+
case ::arrow::Type::UINT64:
615591
dtype = ::tensorflow::DataType::DT_UINT64;
616592
break;
617-
case ::arrow::ipc::feather::fbs::Type::FLOAT:
593+
case ::arrow::Type::FLOAT:
618594
dtype = ::tensorflow::DataType::DT_FLOAT;
619595
break;
620-
case ::arrow::ipc::feather::fbs::Type::DOUBLE:
596+
case ::arrow::Type::DOUBLE:
621597
dtype = ::tensorflow::DataType::DT_DOUBLE;
622598
break;
623-
case ::arrow::ipc::feather::fbs::Type::UTF8:
624-
case ::arrow::ipc::feather::fbs::Type::BINARY:
625-
case ::arrow::ipc::feather::fbs::Type::CATEGORY:
626-
case ::arrow::ipc::feather::fbs::Type::TIMESTAMP:
627-
case ::arrow::ipc::feather::fbs::Type::DATE:
628-
case ::arrow::ipc::feather::fbs::Type::TIME:
629-
// case ::arrow::ipc::feather::fbs::Type::LARGE_UTF8:
630-
// case ::arrow::ipc::feather::fbs::Type::LARGE_BINARY:
599+
case ::arrow::Type::BINARY:
600+
dtype = ::tensorflow::DataType::DT_STRING;
601+
break;
631602
default:
632603
break;
633604
}
634605
shapes_.push_back(TensorShape({static_cast<int64>(table->num_rows())}));
635606
dtypes_.push_back(dtype);
636-
columns_.push_back(table->columns()->Get(i)->name()->str());
637-
columns_index_[table->columns()->Get(i)->name()->str()] = i;
607+
columns_.push_back(schema->field(i)->name());
608+
columns_index_[schema->field(i)->name()] = i;
638609
}
639610

640611
return Status::OK();
@@ -751,6 +722,17 @@ class FeatherReadable : public IOReadableInterface {
751722
FEATHER_PROCESS_TYPE(double,
752723
::arrow::NumericArray<::arrow::DoubleType>);
753724
break;
725+
case DT_STRING: {
726+
int64 curr_index = 0;
727+
for (auto chunk : slice->chunks()) {
728+
for (int64_t item = 0; item < chunk->length(); item++) {
729+
value->flat<tstring>()(curr_index) =
730+
(dynamic_cast<::arrow::BinaryArray*>(chunk.get()))
731+
->GetString(item);
732+
curr_index++;
733+
}
734+
}
735+
} break;
754736
default:
755737
return errors::InvalidArgument("data type is not supported: ",
756738
DataTypeString(value->dtype()));

tests/test_feather.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,24 @@
1717

1818
import os
1919
import tempfile
20+
import pytest
2021

22+
import tensorflow as tf
2123
import tensorflow_io as tfio
2224

2325

24-
def test_feather_format():
26+
@pytest.mark.parametrize(
27+
("version"),
28+
[
29+
1,
30+
2,
31+
],
32+
ids=[
33+
"v1",
34+
"v2",
35+
],
36+
)
37+
def test_feather_format(version):
2538
"""test_feather_format"""
2639
import numpy as np
2740
import pandas as pd
@@ -39,7 +52,7 @@ def test_feather_format():
3952
}
4053
df = pd.DataFrame(data).sort_index(axis=1)
4154
with tempfile.NamedTemporaryFile(delete=False) as f:
42-
pa_feather.write_feather(df, f, version=1)
55+
pa_feather.write_feather(df, f, version=version)
4356

4457
feather = tfio.IOTensor.from_feather(f.name)
4558
for column in df.columns:
@@ -50,5 +63,32 @@ def test_feather_format():
5063
os.unlink(f.name)
5164

5265

66+
def test_binary_feather_format():
67+
"""test_binary_feather_format"""
68+
import numpy as np
69+
import pandas as pd
70+
71+
from pyarrow import feather as pa_feather
72+
import pyarrow as pa
73+
74+
local_path = os.path.join(
75+
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.png"
76+
)
77+
with open(local_path, "rb") as f:
78+
data = [f.read()]
79+
table = pa.Table.from_arrays([data], ["data"])
80+
81+
chunk_size = 1000
82+
with tempfile.NamedTemporaryFile(delete=False) as f:
83+
pa_feather.write_feather(table, f, chunksize=chunk_size)
84+
85+
feather = tfio.IOTensor.from_feather(f.name)
86+
assert feather("data").shape == [1]
87+
assert feather("data").dtype == tf.string
88+
assert np.all(feather("data").to_tensor().numpy() == data[0])
89+
90+
os.unlink(f.name)
91+
92+
5393
if __name__ == "__main__":
5494
test.main()

0 commit comments

Comments
 (0)