diff --git a/CMakeLists.txt b/CMakeLists.txt index 47fab25..c7a085a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,10 +106,11 @@ set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema/private_data.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_output_stream.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp - ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_interval_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp diff --git a/include/sparrow_ipc/deserialize_interval_array.hpp b/include/sparrow_ipc/deserialize_interval_array.hpp new file mode 100644 index 0000000..b8fc725 --- /dev/null +++ b/include/sparrow_ipc/deserialize_interval_array.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include +#include + +#include +#include + +#include "Message_generated.h" +#include "sparrow_ipc/arrow_interface/arrow_array.hpp" +#include "sparrow_ipc/arrow_interface/arrow_schema.hpp" +#include "sparrow_ipc/deserialize_utils.hpp" + +namespace sparrow_ipc +{ + template + [[nodiscard]] sparrow::interval_array deserialize_non_owning_interval_array( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + std::string_view name, + const std::optional>& metadata, + size_t& buffer_index + ) + { + const std::string_view format = data_type_to_format( + sparrow::detail::get_data_type_from_array>::get() + ); + ArrowSchema schema = make_non_owning_arrow_schema( + format, + name.data(), + metadata, + std::nullopt, + 0, + nullptr, + nullptr + ); + + const auto compression = record_batch.compression(); + std::vector buffers; + + auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + + if (compression) + { + buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); + buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); + } + else + { + buffers.emplace_back(validity_buffer_span); + buffers.emplace_back(data_buffer_span); + } + + // TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); + + ArrowArray array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(buffers) + ); + + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; + return sparrow::interval_array{std::move(ap)}; + } +} diff --git a/src/deserialize.cpp b/src/deserialize.cpp index d4b98d1..a29d261 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -3,6 +3,7 @@ #include #include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp" +#include "sparrow_ipc/deserialize_interval_array.hpp" #include "sparrow_ipc/deserialize_primitive_array.hpp" #include "sparrow_ipc/deserialize_variable_size_binary_array.hpp" #include "sparrow_ipc/magic_values.hpp" @@ -10,11 +11,23 @@ namespace sparrow_ipc { + namespace + { + // Integer bit width constants + constexpr int32_t BIT_WIDTH_8 = 8; + constexpr int32_t BIT_WIDTH_16 = 16; + constexpr int32_t BIT_WIDTH_32 = 32; + constexpr int32_t BIT_WIDTH_64 = 64; + + // End-of-stream marker size in bytes + constexpr size_t END_OF_STREAM_MARKER_SIZE = 8; + } const org::apache::arrow::flatbuf::RecordBatch* deserialize_record_batch_message(std::span data, size_t& current_offset) { current_offset += sizeof(uint32_t); - const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset); + const auto message_data = data.subspan(current_offset); + const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data()); if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) { throw std::runtime_error("Expected RecordBatch message, but got a different type."); @@ -27,20 +40,21 @@ namespace sparrow_ipc * * This function processes each field in the schema and deserializes the corresponding * data from the RecordBatch into sparrow::array objects. It handles various Arrow data - * types including primitive types (bool, integers, floating point), binary data, and - * string data with their respective size variants. + * types including primitive types (bool, integers, floating point), binary data, string + * data, fixed-size binary data, and interval types. * * @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data * @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data * @param encapsulated_message The message containing the binary data buffers + * @param field_metadata Metadata associated with each field in the schema * * @return std::vector A vector of deserialized arrays, one for each field in the schema * - * @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision - * is encountered + * @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision, + * or interval unit is encountered * - * The function maintains a buffer index that is incremented as it processes each field - * to correctly map data buffers to their corresponding arrays. + * @note The function maintains a buffer index that is incremented as it processes each field + * to correctly map data buffers to their corresponding arrays. */ std::vector get_arrays_from_record_batch( const org::apache::arrow::flatbuf::RecordBatch& record_batch, @@ -61,7 +75,7 @@ namespace sparrow_ipc const std::optional>& metadata = field_metadata[field_idx++]; const std::string name = field->name() == nullptr ? "" : field->name()->str(); const auto field_type = field->type_type(); - // TODO rename all the deserialize_non_owning... fcts since this is not correct anymore + const auto deserialize_non_owning_primitive_array_lambda = [&]() { return deserialize_non_owning_primitive_array( @@ -81,7 +95,7 @@ namespace sparrow_ipc break; case org::apache::arrow::flatbuf::Type::Int: { - const auto int_type = field->type_as_Int(); + const auto* int_type = field->type_as_Int(); const auto bit_width = int_type->bitWidth(); const bool is_signed = int_type->is_signed(); @@ -90,11 +104,11 @@ namespace sparrow_ipc switch (bit_width) { // clang-format off - case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - default: throw std::runtime_error("Unsupported integer bit width."); + case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); // clang-format on } } @@ -103,11 +117,11 @@ namespace sparrow_ipc switch (bit_width) { // clang-format off - case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - default: throw std::runtime_error("Unsupported integer bit width."); + case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); // clang-format on } } @@ -115,7 +129,7 @@ namespace sparrow_ipc break; case org::apache::arrow::flatbuf::Type::FloatingPoint: { - const auto float_type = field->type_as_FloatingPoint(); + const auto* float_type = field->type_as_FloatingPoint(); switch (float_type->precision()) { // clang-format off @@ -129,14 +143,17 @@ namespace sparrow_ipc arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; default: - throw std::runtime_error("Unsupported floating point precision."); + throw std::runtime_error( + "Unsupported floating point precision: " + + std::to_string(static_cast(float_type->precision())) + ); // clang-format on } break; } case org::apache::arrow::flatbuf::Type::FixedSizeBinary: { - const auto fixed_size_binary_field = field->type_as_FixedSizeBinary(); + const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary(); arrays.emplace_back(deserialize_non_owning_fixedwidthbinary( record_batch, encapsulated_message.body(), @@ -191,8 +208,58 @@ namespace sparrow_ipc ) ); break; + case org::apache::arrow::flatbuf::Type::Interval: + { + const auto* interval_type = field->type_as_Interval(); + const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit(); + switch (interval_unit) + { + case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index + ) + ); + break; + case org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index + ) + ); + break; + case org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + buffer_index + ) + ); + break; + default: + throw std::runtime_error( + "Unsupported interval unit: " + + std::to_string(static_cast(interval_unit)) + ); + } + } + break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error( + "Unsupported field type: " + std::to_string(static_cast(field_type)) + + " for field '" + name + "'" + ); } } return arrays; @@ -206,10 +273,12 @@ namespace sparrow_ipc std::vector fields_nullable; std::vector field_types; std::vector>> fields_metadata; - do + + while (!data.empty()) { - // Check for end-of-stream marker here as data could contain only that (if no record batches present/written) - if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8))) + // Check for end-of-stream marker + if (data.size() >= END_OF_STREAM_MARKER_SIZE + && is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE))) { break; } @@ -234,11 +303,12 @@ namespace sparrow_ipc for (const auto field : *(schema->fields())) { - if(field != nullptr && field->name() != nullptr) + if (field != nullptr && field->name() != nullptr) { - field_names.emplace_back(field->name()->str()); + field_names.emplace_back(field->name()->str()); } - else { + else + { field_names.emplace_back("_unnamed_"); } fields_nullable.push_back(field->nullable()); @@ -256,12 +326,12 @@ namespace sparrow_ipc { if (schema == nullptr) { - throw std::runtime_error("Schema message is missing."); + throw std::runtime_error("RecordBatch encountered before Schema message."); } - const auto record_batch = message->header_as_RecordBatch(); + const auto* record_batch = message->header_as_RecordBatch(); if (record_batch == nullptr) { - throw std::runtime_error("RecordBatch message is missing."); + throw std::runtime_error("RecordBatch message header is null."); } std::vector arrays = get_arrays_from_record_batch( *record_batch, @@ -269,7 +339,7 @@ namespace sparrow_ipc encapsulated_message, fields_metadata ); - auto names_copy = field_names; // TODO: Remove when issue with the to_vector of record_batch is fixed + auto names_copy = field_names; sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays)); record_batches.emplace_back(std::move(sp_record_batch)); } @@ -277,12 +347,12 @@ namespace sparrow_ipc case org::apache::arrow::flatbuf::MessageHeader::Tensor: case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch: case org::apache::arrow::flatbuf::MessageHeader::SparseTensor: - throw std::runtime_error("Not supported"); + throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor"); default: throw std::runtime_error("Unknown message header type."); } data = rest; - } while (!data.empty()); + } return record_batches; } } diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index b0243a1..a2887a3 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -33,6 +33,7 @@ const std::vector files_paths_to_test = { tests_resources_files_path / "generated_large_binary", tests_resources_files_path / "generated_binary_zerolength", tests_resources_files_path / "generated_binary_no_batches", + tests_resources_files_path / "generated_interval", }; const std::vector files_paths_to_test_with_compression = {