diff --git a/shell_wrapper/BUILD b/shell_wrapper/BUILD index 97c47fb..56f35ff 100644 --- a/shell_wrapper/BUILD +++ b/shell_wrapper/BUILD @@ -145,6 +145,66 @@ rust_cxx_bridge( deps = [":shell_types_cc"], ) +# Serialization support +cc_library( + name = "shell_serialization_cc", + srcs = ["shell_serialization.cc"], + hdrs = [ + "shell_serialization.h", + ], + deps = [ + ":shell_serialization_cxx/include", + ":shell_types_cxx", + ":status_cc", + ":status_cxx", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings:string_view", + "@shell-encryption//shell_encryption/rns:serialization_cc_proto", + "@cxx.rs//:core", + ], +) + +cc_test( + name = "shell_serialization_test", + srcs = ["shell_serialization_test.cc"], + deps = [ + ":shell_serialization_cc", + ":shell_types_cc", + ":shell_types_cxx", + ":status_cc", + ":status_matchers", + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "@shell-encryption//shell_encryption/rns:serialization_cc_proto", + "@shell-encryption//shell_encryption/testing:testing_prng", + "@cxx.rs//:core", + ], +) + +rust_library( + name = "shell_serialization", + srcs = ["shell_serialization.rs"], + deps = [ + ":shell_serialization_cc", + ":shell_types", + ":status", + "@protobuf//rust:protobuf", + "@shell-encryption//shell_encryption/rns:serialization_rust_proto", + "@cxx.rs//:cxx", + ], +) + +rust_cxx_bridge( + name = "shell_serialization_cxx", + src = "shell_serialization.rs", + deps = [ + ":shell_serialization_cc", + ":shell_types_cc", + ], +) + cc_library( name = "testing_utils", testonly = 1, diff --git a/shell_wrapper/shell_serialization.cc b/shell_wrapper/shell_serialization.cc new file mode 100644 index 0000000..22da0bd --- /dev/null +++ b/shell_wrapper/shell_serialization.cc @@ -0,0 +1,73 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "shell_wrapper/shell_serialization.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "include/cxx.h" +#include "shell_encryption/rns/serialization.pb.h" +#include "shell_wrapper/shell_types.rs.h" +#include "shell_wrapper/status.h" +#include "shell_wrapper/status.rs.h" + +using secure_aggregation::MakeFfiStatus; + +FfiStatus SerializeRnsPolynomialToBytes(const RnsPolynomialWrapper* poly, + ModuliWrapper moduli, + std::unique_ptr& out) { + if (poly == nullptr || poly->ptr == nullptr || moduli.moduli == nullptr) { + return MakeFfiStatus(absl::InvalidArgumentError( + "All pointer arguments and their wrapped pointers must be non-null.")); + } + auto serialized = poly->ptr->Serialize({moduli.moduli, moduli.len}); + if (!serialized.ok()) { + return MakeFfiStatus(serialized.status()); + } + std::string buffer; + serialized->SerializeToString(&buffer); + out = std::make_unique(std::move(buffer)); + return {}; +} + +FfiStatus DeserializeRnsPolynomialFromBytes( + rust::Slice serialized_poly, ModuliWrapper moduli, + RnsPolynomialWrapper* out) { + if (out == nullptr || out->ptr == nullptr || moduli.moduli == nullptr) { + return MakeFfiStatus(absl::InvalidArgumentError( + "All pointer arguments and their wrapped pointers must be non-null.")); + } + rlwe::SerializedRnsPolynomial serialized_poly_proto; + if (!serialized_poly_proto.ParseFromString(absl::string_view( + reinterpret_cast(serialized_poly.data()), + serialized_poly.size()))) { + return MakeFfiStatus(absl::InvalidArgumentError( + "Failed to parse serialized RNS polynomial")); + } + auto poly = secure_aggregation::RnsPolynomial::Deserialize( + serialized_poly_proto, {moduli.moduli, moduli.len}); + if (!poly.ok()) { + return MakeFfiStatus(poly.status()); + } + out->ptr = std::make_unique( + std::move(poly.value())); + return MakeFfiStatus(); +} diff --git a/shell_wrapper/shell_serialization.h b/shell_wrapper/shell_serialization.h new file mode 100644 index 0000000..8c7cc5a --- /dev/null +++ b/shell_wrapper/shell_serialization.h @@ -0,0 +1,41 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECURE_AGGREGATION_SHELL_WRAPPER_SHELL_SERIALIZATION_H_ +#define SECURE_AGGREGATION_SHELL_WRAPPER_SHELL_SERIALIZATION_H_ + +#include +#include +#include + +#include "include/cxx.h" +#include "shell_wrapper/shell_serialization.rs.h" +#include "shell_wrapper/shell_types.rs.h" +#include "shell_wrapper/status.rs.h" + +extern "C" { + +FfiStatus SerializeRnsPolynomialToBytes(const RnsPolynomialWrapper* poly, + ModuliWrapper moduli, + std::unique_ptr& out); + +FfiStatus DeserializeRnsPolynomialFromBytes( + rust::Slice serialized_poly, ModuliWrapper moduli, + RnsPolynomialWrapper* out); + +} // extern "C" + +#endif // SECURE_AGGREGATION_SHELL_WRAPPER_SHELL_SERIALIZATION_H_ diff --git a/shell_wrapper/shell_serialization.rs b/shell_wrapper/shell_serialization.rs new file mode 100644 index 0000000..b1d1e92 --- /dev/null +++ b/shell_wrapper/shell_serialization.rs @@ -0,0 +1,91 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Rust wrapper for serialization support of SHELL types. + +use protobuf::prelude::*; +use serialization_rust_proto::SerializedRnsPolynomial; +use shell_types::{create_empty_rns_polynomial, Moduli, RnsPolynomial}; +use status::{StatusError, StatusErrorCode}; + +#[cxx::bridge] +mod ffi { + unsafe extern "C++" { + include!("shell_wrapper/shell_serialization.h"); + include!("shell_wrapper/shell_types.h"); + + type FfiStatus = shell_types::ffi::FfiStatus; + type ModuliWrapper = shell_types::ffi::ModuliWrapper; + type RnsPolynomialWrapper = shell_types::ffi::RnsPolynomialWrapper; + + pub unsafe fn SerializeRnsPolynomialToBytes( + poly: *const RnsPolynomialWrapper, + moduli: ModuliWrapper, + out: &mut UniquePtr, + ) -> FfiStatus; + + pub unsafe fn DeserializeRnsPolynomialFromBytes( + serialized_poly: &[u8], + moduli: ModuliWrapper, + out: *mut RnsPolynomialWrapper, + ) -> FfiStatus; + + } +} + +use status::rust_status_from_cpp; + +// Serialize a RnsPolynomial to a SerializedRnsPolynomial proto. +pub fn serialize_rns_polynomial( + poly: &RnsPolynomial, + moduli: &Moduli, +) -> Result { + let mut out = cxx::UniquePtr::null(); + // SAFETY: No lifetime constraints (no references are kept by the C++ function). + // `SerializeRnsPolynomialToBytes` allocates a C++ string to write the proto bytes to, and assigns + // the string to `out`. + rust_status_from_cpp(unsafe { + ffi::SerializeRnsPolynomialToBytes(poly, moduli.moduli, &mut out) + })?; + SerializedRnsPolynomial::parse(out.as_bytes()).map_err(|parse_error| { + StatusError::new_with_current_location( + StatusErrorCode::Internal, + format!("{parse_error:?}"), + ) + }) +} + +// Deserialize a SerializedRnsPolynomial proto to a RnsPolynomial. +pub fn deserialize_rns_polynomial( + serialized: SerializedRnsPolynomial, + moduli: &Moduli, +) -> Result { + let serialized_bytes = serialized.serialize().map_err(|serialize_error| { + StatusError::new_with_current_location( + StatusErrorCode::Internal, + format!("{serialize_error:?}"), + ) + })?; + + // SAFETY: No lifetime constraints (`create_empty_rns_polynomial` creates and returns an empty + // C++ object). + let mut poly = unsafe { create_empty_rns_polynomial() }; + + // SAFETY: No lifetime constraints (no references are kept by the C++ function). + // `DeserializeRnsPolynomialFromBytes` allocates a C++ RnsPolynomial object and assigns it to `poly`. + rust_status_from_cpp(unsafe { + ffi::DeserializeRnsPolynomialFromBytes(&serialized_bytes, moduli.moduli, &mut poly) + })?; + Ok(poly) +} diff --git a/shell_wrapper/shell_serialization_test.cc b/shell_wrapper/shell_serialization_test.cc new file mode 100644 index 0000000..1cab4d5 --- /dev/null +++ b/shell_wrapper/shell_serialization_test.cc @@ -0,0 +1,165 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "shell_wrapper/shell_serialization.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "include/cxx.h" +#include "shell_encryption/rns/serialization.pb.h" +#include "shell_encryption/testing/testing_prng.h" +#include "shell_wrapper/shell_aliases.h" +#include "shell_wrapper/shell_types.h" +#include "shell_wrapper/shell_types.rs.h" +#include "shell_wrapper/status.h" +#include "shell_wrapper/status_matchers.h" + +namespace secure_aggregation { +namespace { + +using secure_aggregation::secagg_internal::StatusIs; +using ::testing::HasSubstr; + +constexpr int kLogN = 12; +const std::vector kQs = {1125899906826241ULL, 1125899906629633ULL}; + +TEST(ShellSerializationTest, SerializeRnsPolynomialToBytesFailsOnNullptr) { + constexpr int kT = 2; // Dummy plaintext modulus. + SECAGG_ASSERT_OK_AND_ASSIGN(auto rns_context, + RnsContext::Create(kLogN, kQs, + /*ps=*/{}, kT)); + auto moduli = rns_context.MainPrimeModuli(); + auto moduli_wrapper = + ModuliWrapper{.moduli = moduli.data(), .len = moduli.size()}; + + auto serialized_bytes = std::make_unique(); + EXPECT_THAT( + UnwrapFfiStatus(SerializeRnsPolynomialToBytes( + /*poly=*/nullptr, moduli_wrapper, serialized_bytes)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); + + RnsPolynomialWrapper null_poly_wrapper = {.ptr = nullptr}; + EXPECT_THAT( + UnwrapFfiStatus(SerializeRnsPolynomialToBytes( + &null_poly_wrapper, moduli_wrapper, serialized_bytes)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto poly, + RnsPolynomial::CreateZero(kLogN, moduli)); + RnsPolynomialWrapper poly_wrapper = { + .ptr = std::make_unique(std::move(poly))}; + EXPECT_THAT( + UnwrapFfiStatus(SerializeRnsPolynomialToBytes( + &poly_wrapper, ModuliWrapper{.moduli = nullptr, .len = 0}, + serialized_bytes)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); +} + +TEST(ShellSerializationTest, DeserializeRnsPolynomialFromBytesFailsOnNullptr) { + constexpr int kT = 2; // Dummy plaintext modulus. + SECAGG_ASSERT_OK_AND_ASSIGN(auto rns_context, + RnsContext::Create(kLogN, kQs, + /*ps=*/{}, kT)); + auto moduli = rns_context.MainPrimeModuli(); + auto moduli_wrapper = + ModuliWrapper{.moduli = moduli.data(), .len = moduli.size()}; + std::string empty_serialized_bytes; + rust::Slice empty_serialized( + reinterpret_cast(empty_serialized_bytes.data()), 0); + EXPECT_THAT( + UnwrapFfiStatus(DeserializeRnsPolynomialFromBytes( + empty_serialized, moduli_wrapper, /*out=*/nullptr)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); + + RnsPolynomialWrapper null_poly_wrapper = {.ptr = nullptr}; + EXPECT_THAT( + UnwrapFfiStatus(DeserializeRnsPolynomialFromBytes( + empty_serialized, moduli_wrapper, &null_poly_wrapper)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto poly, + RnsPolynomial::CreateZero(kLogN, moduli)); + RnsPolynomialWrapper poly_wrapper = { + .ptr = std::make_unique(std::move(poly))}; + EXPECT_THAT( + UnwrapFfiStatus(DeserializeRnsPolynomialFromBytes( + empty_serialized, ModuliWrapper{.moduli = nullptr, .len = 0}, + /*out=*/nullptr)), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("non-null"))); +} + +// Tests that the output of SerializeRnsPolynomialToBytes can be deserialized +// to the same RnsPolynomial. +TEST(ShellSerializationTest, SerializeRnsPolynomialToBytes) { + constexpr int kT = 2; // Dummy plaintext modulus. + SECAGG_ASSERT_OK_AND_ASSIGN(auto rns_context, + RnsContext::Create(kLogN, kQs, + /*ps=*/{}, kT)); + auto moduli = rns_context.MainPrimeModuli(); + auto moduli_wrapper = + ModuliWrapper{.moduli = moduli.data(), .len = moduli.size()}; + auto prng = rlwe::testing::TestingPrng(0); + SECAGG_ASSERT_OK_AND_ASSIGN( + auto poly, RnsPolynomial::SampleUniform(kLogN, &prng, moduli)); + RnsPolynomialWrapper poly_wrapper = { + .ptr = std::make_unique(std::move(poly))}; + auto serialized_bytes = std::make_unique(); + SECAGG_EXPECT_OK(UnwrapFfiStatus(SerializeRnsPolynomialToBytes( + &poly_wrapper, moduli_wrapper, serialized_bytes))); + + // Create a proto from the serialized bytes and then deserialize it. + rlwe::SerializedRnsPolynomial serialized_poly_proto; + ASSERT_TRUE(serialized_poly_proto.ParseFromString(*serialized_bytes)); + SECAGG_ASSERT_OK_AND_ASSIGN( + auto deserialized_poly, + RnsPolynomial::Deserialize(serialized_poly_proto, moduli)); + EXPECT_EQ(deserialized_poly, *(poly_wrapper.ptr)); +} + +TEST(ShellSerializationTest, DeserializeRnsPolynomialFromBytes) { + constexpr int kT = 2; // Dummy plaintext modulus. + SECAGG_ASSERT_OK_AND_ASSIGN(auto rns_context, + RnsContext::Create(kLogN, kQs, + /*ps=*/{}, kT)); + auto moduli = rns_context.MainPrimeModuli(); + auto moduli_wrapper = + ModuliWrapper{.moduli = moduli.data(), .len = moduli.size()}; + auto prng = rlwe::testing::TestingPrng(0); + SECAGG_ASSERT_OK_AND_ASSIGN( + auto poly, RnsPolynomial::SampleUniform(kLogN, &prng, moduli)); + + // Serialize the RnsPolynomial to bytes. + SECAGG_ASSERT_OK_AND_ASSIGN(auto serialized_proto, poly.Serialize(moduli)); + std::string serialized_bytes; + serialized_proto.SerializeToString(&serialized_bytes); + rust::Slice serialized_poly( + reinterpret_cast(serialized_bytes.data()), + serialized_bytes.size()); + + // Deserialize the bytes to an RnsPolynomial. + RnsPolynomialWrapper poly_wrapper = CreateEmptyRnsPolynomialWrapper(); + SECAGG_EXPECT_OK(UnwrapFfiStatus(DeserializeRnsPolynomialFromBytes( + serialized_poly, moduli_wrapper, &poly_wrapper))); + EXPECT_EQ(poly, *(poly_wrapper.ptr)); +} + +} // namespace +} // namespace secure_aggregation