Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions willow/benches/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ rust_library(
"//willow/src/shell:single_thread_hkdf",
"//willow/src/shell:vahe_shell",
"//willow/src/testing_utils",
"//willow/src/traits:ahe_traits",
"//willow/src/traits:client_traits",
"//willow/src/traits:decryptor_traits",
"//willow/src/traits:kahe_traits",
"//willow/src/traits:prng_traits",
"//willow/src/traits:server_traits",
"//willow/src/traits:vahe_traits",
"//willow/src/traits:verifier_traits",
"//willow/src/willow_v1:willow_v1_client",
"//willow/src/willow_v1:willow_v1_common",
Expand Down
22 changes: 12 additions & 10 deletions willow/benches/shell_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ use std::collections::HashMap;
use std::hint::black_box;
use std::time::Duration;

use ahe_traits::AheBase;
use client_traits::SecureAggregationClient;
use decryptor_traits::SecureAggregationDecryptor;
use kahe_shell::ShellKahe;
use kahe_traits::KaheBase;
use parameters_shell::create_shell_configs;
use prng_traits::SecurePrng;
use server_traits::SecureAggregationServer;
use parameters_shell::create_shell_configs;
use single_thread_hkdf::SingleThreadHkdfPrng;
use testing_utils::{generate_random_unsigned_vector, ShellClient, ShellClientMessage};
use vahe_shell::ShellVahe;
use vahe_traits::VaheBase;
use verifier_traits::SecureAggregationVerifier;
use willow_api_common::AggregationConfig;
use willow_v1_client::WillowV1Client;
Expand All @@ -39,6 +41,7 @@ use willow_v1_server::{ServerState, WillowV1Server};
use willow_v1_verifier::{VerifierState, WillowV1Verifier};

const DEFAULT_ID: &str = "default";
const CONTEXT_STRING: &[u8] = b"benchmark_context_string";

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -126,21 +129,20 @@ fn setup_base(args: &Args) -> BaseInputs {
};
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config).unwrap();
let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap();
let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap();

// Create client.
let common = WillowCommon {
kahe: ShellKahe::new(kahe_config.clone(), &public_kahe_seed).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), &public_ahe_seed).unwrap(),
kahe: ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(),
};
let seed = SingleThreadHkdfPrng::generate_seed().unwrap();
let prng = SingleThreadHkdfPrng::create(&seed).unwrap();
let client = ShellClient { common, prng };

// Create decryptor, which needs its own `common` and `prng`.
let common = WillowCommon {
kahe: ShellKahe::new(kahe_config.clone(), &public_kahe_seed).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), &public_ahe_seed).unwrap(),
kahe: ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(),
};
let seed = SingleThreadHkdfPrng::generate_seed().unwrap();
let prng = SingleThreadHkdfPrng::create(&seed).unwrap();
Expand All @@ -149,16 +151,16 @@ fn setup_base(args: &Args) -> BaseInputs {

// Create server.
let common = WillowCommon {
kahe: ShellKahe::new(kahe_config.clone(), &public_kahe_seed).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), &public_ahe_seed).unwrap(),
kahe: ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(),
};
let server = WillowV1Server { common };
let mut server_state = ServerState::new();

// Create verifier.
let common = WillowCommon {
kahe: ShellKahe::new(kahe_config.clone(), &public_kahe_seed).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), &public_ahe_seed).unwrap(),
kahe: ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(),
vahe: ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(),
};
let verifier = WillowV1Verifier { common };
let verifier_state = VerifierState::new();
Expand Down
54 changes: 29 additions & 25 deletions willow/src/shell/ahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,6 @@ pub struct ShellAhe {
}

impl ShellAhe {
pub fn new(config: ShellAheConfig, public_seed: &Seed) -> Result<Self, status::StatusError> {
let num_coeffs = 1 << config.log_n;
let public_ahe_parameters = ahe::create_public_parameters(
config.log_n,
config.t,
&config.qs,
/* error_variance= */ ERROR_VARIANCE,
/* s_base_flood= */ S_BASE_FLOOD,
config.s_flood,
&public_seed,
)?;

Ok(Self { public_ahe_parameters, num_coeffs })
}

/// Convenience function.
fn add_vec_rns_polynomial_in_place(
&self,
Expand Down Expand Up @@ -400,6 +385,29 @@ impl AheBase for ShellAhe {

type Rng = SingleThreadHkdfPrng;

type Config = ShellAheConfig;

fn new(config: Self::Config, context_string: &[u8]) -> Result<Self, status::StatusError> {
let num_coeffs = 1 << config.log_n;
let public_seed = single_thread_hkdf::compute_hkdf(
context_string,
b"",
b"ShellAhe.public_seed",
single_thread_hkdf::seed_length(),
)?;
let public_ahe_parameters = ahe::create_public_parameters(
config.log_n,
config.t,
&config.qs,
/* error_variance= */ ERROR_VARIANCE,
/* s_base_flood= */ S_BASE_FLOOD,
config.s_flood,
&public_seed,
)?;

Ok(Self { public_ahe_parameters, num_coeffs })
}

fn aggregate_public_key_shares(
&self,
public_key_shares: &[Self::PublicKeyShare],
Expand Down Expand Up @@ -651,13 +659,13 @@ mod test {
const NUM_DECRYPTORS: usize = 3;
const NUM_CLIENTS: usize = 1000;
const MAX_ABSOLUTE_VALUE: i64 = 72;
const CONTEXT_STRING: &[u8] = b"test_context_string";

#[gtest]
fn test_encrypt_decrypt_one() -> googletest::Result<()> {
const NUM_VALUES: usize = 100;

let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let ahe = ShellAhe::new(make_ahe_config(), &public_seed)?;
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;

let pt = vec![1, 2, 3, 4, 5, 6, 7, 8];
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -682,8 +690,7 @@ mod test {
let config = make_ahe_config();
let t = config.t; // Keep a copy of the plaintext modulus.

let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let ahe = ShellAhe::new(config, &public_seed)?;
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -750,8 +757,7 @@ mod test {

#[gtest]
fn test_errors() -> googletest::Result<()> {
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let ahe = ShellAhe::new(make_ahe_config(), &public_seed)?;
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -826,11 +832,10 @@ mod test {

#[gtest]
fn test_manual_encryption() -> googletest::Result<()> {
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let config = make_ahe_config();
let q: i128 = config.qs.iter().map(|x| *x as i128).product();

let ahe = ShellAhe::new(config, &public_seed)?;
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;
Expand Down Expand Up @@ -871,9 +876,8 @@ mod test {

#[gtest]
fn test_export_ciphertext_has_right_order() -> googletest::Result<()> {
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let config = make_ahe_config();
let ahe = ShellAhe::new(config, &public_seed)?;
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;
Expand Down
60 changes: 32 additions & 28 deletions willow/src/shell/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,6 @@ pub struct ShellKahe {
}

impl ShellKahe {
pub fn new(
shell_kahe_config: ShellKaheConfig,
public_seed: &Seed,
) -> Result<Self, status::StatusError> {
Self::validate_kahe_config(&shell_kahe_config)?;
let num_coeffs = 1 << shell_kahe_config.log_n;
let public_kahe_parameters = kahe::create_public_parameters(
shell_kahe_config.log_n as u64,
shell_kahe_config.log_t as u64,
&shell_kahe_config.moduli,
shell_kahe_config.num_public_polynomials,
&public_seed,
)?;
Ok(Self { config: shell_kahe_config, num_coeffs, public_kahe_parameters })
}

/// Validates KAHE parameters in ShellKaheConfig.
fn validate_kahe_config(config: &ShellKaheConfig) -> Result<(), status::StatusError> {
if config.log_t > BIG_INT_BITS {
Expand Down Expand Up @@ -113,6 +97,30 @@ impl KaheBase for ShellKahe {

type Rng = SingleThreadHkdfPrng;

type Config = ShellKaheConfig;

fn new(
shell_kahe_config: Self::Config,
context_string: &[u8],
) -> Result<Self, status::StatusError> {
Self::validate_kahe_config(&shell_kahe_config)?;
let num_coeffs = 1 << shell_kahe_config.log_n;
let public_seed = single_thread_hkdf::compute_hkdf(
context_string,
b"",
b"ShellKahe.public_seed",
single_thread_hkdf::seed_length(),
)?;
let public_kahe_parameters = kahe::create_public_parameters(
shell_kahe_config.log_n as u64,
shell_kahe_config.log_t as u64,
&shell_kahe_config.moduli,
shell_kahe_config.num_public_polynomials,
&public_seed,
)?;
Ok(Self { config: shell_kahe_config, num_coeffs, public_kahe_parameters })
}

fn add_keys_in_place(
&self,
left: &Self::SecretKey,
Expand Down Expand Up @@ -299,6 +307,8 @@ mod test {
/// Default ID used in tests.
const DEFAULT_ID: &str = "default";

const CONTEXT_STRING: &[u8] = b"test_context_string";

#[gtest]
fn test_encrypt_decrypt_short() -> googletest::Result<()> {
let plaintext_modulus_bits = 39;
Expand All @@ -307,8 +317,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;

let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -327,8 +336,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;

let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand Down Expand Up @@ -364,8 +372,7 @@ mod test {
packed_vector_config.num_packed_coeffs = num_messages;
set_kahe_num_public_polynomials(&mut kahe_config);

let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;

let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
Expand Down Expand Up @@ -397,8 +404,7 @@ mod test {
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;

let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -434,8 +440,7 @@ mod test {
let packed_vector_configs = HashMap::from([]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;

let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -477,8 +482,7 @@ mod test {
let plaintext_modulus_bits = 39;
let packed_vector_configs = HashMap::from([]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let public_seed = SingleThreadHkdfPrng::generate_seed()?;
let kahe = ShellKahe::new(kahe_config, &public_seed)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;

// The seed used to sample the secret keys.
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand Down
Loading