Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ set(SPARROW_IPC_HEADERS
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema/private_data.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_output_stream.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_interval_array.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp
${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp
Expand Down
71 changes: 71 additions & 0 deletions include/sparrow_ipc/deserialize_interval_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include <optional>
#include <vector>

#include <sparrow/arrow_interface/arrow_array_schema_proxy.hpp>
#include <sparrow/interval_array.hpp>

#include "Message_generated.h"
#include "sparrow_ipc/arrow_interface/arrow_array.hpp"
#include "sparrow_ipc/arrow_interface/arrow_schema.hpp"
#include "sparrow_ipc/deserialize_utils.hpp"

namespace sparrow_ipc
{
template <typename T>
[[nodiscard]] sparrow::interval_array<T> deserialize_non_owning_interval_array(
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
std::span<const uint8_t> body,
std::string_view name,
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
size_t& buffer_index
)
{
const std::string_view format = data_type_to_format(
sparrow::detail::get_data_type_from_array<sparrow::interval_array<T>>::get()
);
ArrowSchema schema = make_non_owning_arrow_schema(
format,
name.data(),
metadata,
std::nullopt,
0,
nullptr,
nullptr
);

const auto compression = record_batch.compression();
std::vector<arrow_array_private_data::optionally_owned_buffer> buffers;

auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index);

if (compression)
{
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
}
else
{
buffers.emplace_back(validity_buffer_span);
buffers.emplace_back(data_buffer_span);
}

// TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length());

ArrowArray array = make_arrow_array<arrow_array_private_data>(
record_batch.length(),
null_count,
0,
0,
nullptr,
nullptr,
std::move(buffers)
);

sparrow::arrow_proxy ap{std::move(array), std::move(schema)};
return sparrow::interval_array<T>{std::move(ap)};
}
}
140 changes: 105 additions & 35 deletions src/deserialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,31 @@
#include <sparrow/types/data_type.hpp>

#include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp"
#include "sparrow_ipc/deserialize_interval_array.hpp"
#include "sparrow_ipc/deserialize_primitive_array.hpp"
#include "sparrow_ipc/deserialize_variable_size_binary_array.hpp"
#include "sparrow_ipc/magic_values.hpp"
#include "sparrow_ipc/metadata.hpp"

namespace sparrow_ipc
{
namespace
{
// Integer bit width constants
constexpr int32_t BIT_WIDTH_8 = 8;
constexpr int32_t BIT_WIDTH_16 = 16;
constexpr int32_t BIT_WIDTH_32 = 32;
constexpr int32_t BIT_WIDTH_64 = 64;

// End-of-stream marker size in bytes
constexpr size_t END_OF_STREAM_MARKER_SIZE = 8;
}
const org::apache::arrow::flatbuf::RecordBatch*
deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset)
{
current_offset += sizeof(uint32_t);
const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset);
const auto message_data = data.subspan(current_offset);
const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data());
if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch)
{
throw std::runtime_error("Expected RecordBatch message, but got a different type.");
Expand All @@ -27,20 +40,21 @@ namespace sparrow_ipc
*
* This function processes each field in the schema and deserializes the corresponding
* data from the RecordBatch into sparrow::array objects. It handles various Arrow data
* types including primitive types (bool, integers, floating point), binary data, and
* string data with their respective size variants.
* types including primitive types (bool, integers, floating point), binary data, string
* data, fixed-size binary data, and interval types.
*
* @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data
* @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data
* @param encapsulated_message The message containing the binary data buffers
* @param field_metadata Metadata associated with each field in the schema
*
* @return std::vector<sparrow::array> A vector of deserialized arrays, one for each field in the schema
*
* @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision
* is encountered
* @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision,
* or interval unit is encountered
*
* The function maintains a buffer index that is incremented as it processes each field
* to correctly map data buffers to their corresponding arrays.
* @note The function maintains a buffer index that is incremented as it processes each field
* to correctly map data buffers to their corresponding arrays.
*/
std::vector<sparrow::array> get_arrays_from_record_batch(
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
Expand All @@ -61,7 +75,7 @@ namespace sparrow_ipc
const std::optional<std::vector<sparrow::metadata_pair>>& metadata = field_metadata[field_idx++];
const std::string name = field->name() == nullptr ? "" : field->name()->str();
const auto field_type = field->type_type();
// TODO rename all the deserialize_non_owning... fcts since this is not correct anymore

const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>()
{
return deserialize_non_owning_primitive_array<T>(
Expand All @@ -81,7 +95,7 @@ namespace sparrow_ipc
break;
case org::apache::arrow::flatbuf::Type::Int:
{
const auto int_type = field->type_as_Int();
const auto* int_type = field->type_as_Int();
const auto bit_width = int_type->bitWidth();
const bool is_signed = int_type->is_signed();

Expand All @@ -90,11 +104,11 @@ namespace sparrow_ipc
switch (bit_width)
{
// clang-format off
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
default: throw std::runtime_error("Unsupported integer bit width.");
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
// clang-format on
}
}
Expand All @@ -103,19 +117,19 @@ namespace sparrow_ipc
switch (bit_width)
{
// clang-format off
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
default: throw std::runtime_error("Unsupported integer bit width.");
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
// clang-format on
}
}
}
break;
case org::apache::arrow::flatbuf::Type::FloatingPoint:
{
const auto float_type = field->type_as_FloatingPoint();
const auto* float_type = field->type_as_FloatingPoint();
switch (float_type->precision())
{
// clang-format off
Expand All @@ -129,14 +143,17 @@ namespace sparrow_ipc
arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<double>());
break;
default:
throw std::runtime_error("Unsupported floating point precision.");
throw std::runtime_error(
"Unsupported floating point precision: "
+ std::to_string(static_cast<int>(float_type->precision()))
);
// clang-format on
}
break;
}
case org::apache::arrow::flatbuf::Type::FixedSizeBinary:
{
const auto fixed_size_binary_field = field->type_as_FixedSizeBinary();
const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary();
arrays.emplace_back(deserialize_non_owning_fixedwidthbinary(
record_batch,
encapsulated_message.body(),
Expand Down Expand Up @@ -191,8 +208,58 @@ namespace sparrow_ipc
)
);
break;
case org::apache::arrow::flatbuf::Type::Interval:
{
const auto* interval_type = field->type_as_Interval();
const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit();
switch (interval_unit)
{
case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH:
arrays.emplace_back(
deserialize_non_owning_interval_array<sparrow::chrono::months>(
record_batch,
encapsulated_message.body(),
name,
metadata,
buffer_index
)
);
break;
case org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME:
arrays.emplace_back(
deserialize_non_owning_interval_array<sparrow::days_time_interval>(
record_batch,
encapsulated_message.body(),
name,
metadata,
buffer_index
)
);
break;
case org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO:
arrays.emplace_back(
deserialize_non_owning_interval_array<sparrow::month_day_nanoseconds_interval>(
record_batch,
encapsulated_message.body(),
name,
metadata,
buffer_index
)
);
break;
default:
throw std::runtime_error(
"Unsupported interval unit: "
+ std::to_string(static_cast<int>(interval_unit))
);
}
}
break;
default:
throw std::runtime_error("Unsupported type.");
throw std::runtime_error(
"Unsupported field type: " + std::to_string(static_cast<int>(field_type))
+ " for field '" + name + "'"
);
}
}
return arrays;
Expand All @@ -206,10 +273,12 @@ namespace sparrow_ipc
std::vector<bool> fields_nullable;
std::vector<sparrow::data_type> field_types;
std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata;
do

while (!data.empty())
{
// Check for end-of-stream marker here as data could contain only that (if no record batches present/written)
if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8)))
// Check for end-of-stream marker
if (data.size() >= END_OF_STREAM_MARKER_SIZE
&& is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE)))
{
break;
}
Expand All @@ -234,11 +303,12 @@ namespace sparrow_ipc

for (const auto field : *(schema->fields()))
{
if(field != nullptr && field->name() != nullptr)
if (field != nullptr && field->name() != nullptr)
{
field_names.emplace_back(field->name()->str());
field_names.emplace_back(field->name()->str());
}
else {
else
{
field_names.emplace_back("_unnamed_");
}
fields_nullable.push_back(field->nullable());
Expand All @@ -256,33 +326,33 @@ namespace sparrow_ipc
{
if (schema == nullptr)
{
throw std::runtime_error("Schema message is missing.");
throw std::runtime_error("RecordBatch encountered before Schema message.");
}
const auto record_batch = message->header_as_RecordBatch();
const auto* record_batch = message->header_as_RecordBatch();
if (record_batch == nullptr)
{
throw std::runtime_error("RecordBatch message is missing.");
throw std::runtime_error("RecordBatch message header is null.");
}
std::vector<sparrow::array> arrays = get_arrays_from_record_batch(
*record_batch,
*schema,
encapsulated_message,
fields_metadata
);
auto names_copy = field_names; // TODO: Remove when issue with the to_vector of record_batch is fixed
auto names_copy = field_names;
sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays));
record_batches.emplace_back(std::move(sp_record_batch));
}
break;
case org::apache::arrow::flatbuf::MessageHeader::Tensor:
case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch:
case org::apache::arrow::flatbuf::MessageHeader::SparseTensor:
throw std::runtime_error("Not supported");
throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor");
default:
throw std::runtime_error("Unknown message header type.");
}
data = rest;
} while (!data.empty());
}
return record_batches;
}
}
1 change: 1 addition & 0 deletions tests/test_de_serialization_with_files.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const std::vector<std::filesystem::path> files_paths_to_test = {
tests_resources_files_path / "generated_large_binary",
tests_resources_files_path / "generated_binary_zerolength",
tests_resources_files_path / "generated_binary_no_batches",
tests_resources_files_path / "generated_interval",
};

const std::vector<std::filesystem::path> files_paths_to_test_with_compression = {
Expand Down
Loading