Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion willow/benches/shell_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn setup_base(args: &Args) -> BaseInputs {
max_number_of_decryptors: 1,
max_number_of_clients: args.max_num_clients as i64,
max_decryptor_dropouts: 0,
session_id: String::from("benchmark"),
key_id: b"benchmark".to_vec(),
};
let ahe_config = create_shell_ahe_config(aggregation_config.max_number_of_decryptors).unwrap();
let kahe_config = create_shell_kahe_config(&aggregation_config).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion willow/proto/willow/aggregation_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ message AggregationConfigProto {
int64 max_number_of_decryptors = 5;
int64 max_decryptor_dropouts = 2;
int64 max_number_of_clients = 3;
string session_id = 4;
string session_id = 4 [deprecated = true];
bytes key_id = 6;
}

// The configuration for a single vector in an aggregation.
Expand Down
24 changes: 6 additions & 18 deletions willow/src/api/aggregation_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ use std::collections::HashMap;
/// aggregation failing.
/// max_number_of_clients: The maximum number of clients that will participate in the
/// aggregation.
/// session_id: The session id of the aggregation.
/// key_id: The key id of the aggregation, used as context_bytes to seed Kahe
/// and Vahe public parameters. Must be unique for each instantiation.
/// willow_version: The version of the willow protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregationConfig {
pub vector_lengths_and_bounds: HashMap<String, (isize, i64)>,
pub max_number_of_decryptors: i64,
pub max_decryptor_dropouts: i64,
pub max_number_of_clients: i64,
pub session_id: String,
pub key_id: Vec<u8>,
}

impl FromProto for AggregationConfig {
Expand All @@ -57,7 +58,7 @@ impl FromProto for AggregationConfig {
max_number_of_decryptors: proto.max_number_of_decryptors(),
max_decryptor_dropouts: proto.max_decryptor_dropouts(),
max_number_of_clients: proto.max_number_of_clients(),
session_id: proto.session_id().to_string(),
key_id: proto.key_id().to_vec(),
})
}
}
Expand All @@ -71,7 +72,7 @@ impl ToProto for AggregationConfig {
max_number_of_decryptors: self.max_number_of_decryptors,
max_decryptor_dropouts: self.max_decryptor_dropouts,
max_number_of_clients: self.max_number_of_clients,
session_id: self.session_id.clone(),
key_id: self.key_id.clone(),
});
aggregation_config_proto.vector_configs_mut().copy_from(
self.vector_lengths_and_bounds.iter().map(|(key, (length, bound))| {
Expand All @@ -82,19 +83,6 @@ impl ToProto for AggregationConfig {
}
}

impl AggregationConfig {
/// Computes context bytes by hashing the session ID in the config.
pub fn compute_context_bytes(&self) -> Result<Vec<u8>, StatusError> {
let context_seed = single_thread_hkdf::compute_hkdf(
self.session_id.as_bytes(),
b"",
b"AggregationConfig.context_string",
single_thread_hkdf::seed_length(),
)?;
Ok(context_seed.as_bytes().to_vec())
}
}

#[cfg(test)]
mod tests {
use crate::AggregationConfig;
Expand All @@ -109,7 +97,7 @@ mod tests {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 1,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};

verify_that!(
Expand Down
2 changes: 1 addition & 1 deletion willow/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl WillowShellClient {
})?;
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?;
let context_bytes = aggregation_config.compute_context_bytes()?;
let context_bytes = &aggregation_config.key_id;
let kahe = ShellKahe::new(kahe_config, &context_bytes)?;
let vahe = ShellVahe::new(ahe_config, &context_bytes)?;
let client = WillowV1Client::new_with_randomly_generated_seed(kahe, vahe)?;
Expand Down
2 changes: 1 addition & 1 deletion willow/src/api/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ AggregationConfigProto CreateTestConfig() {
(*config.mutable_vector_configs())["metric1"] = vector_config;
config.set_max_number_of_decryptors(1);
config.set_max_number_of_clients(10);
config.set_session_id("test");
config.set_key_id("test");
return config;
}

Expand Down
14 changes: 7 additions & 7 deletions willow/src/api/server_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ pub struct ServerAccumulator {

impl ServerAccumulator {
fn new(aggregation_config: AggregationConfig) -> Result<Self, StatusError> {
let context_string = aggregation_config.compute_context_bytes()?;
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
let server_kahe = ShellKahe::new(kahe_config, &context_string)?;
let server_vahe = ShellVahe::new(vahe_config.clone(), &context_string)?;
let verifier_vahe = ShellVahe::new(vahe_config, &context_string)?;
let context_bytes = &aggregation_config.key_id;
let server_kahe = ShellKahe::new(kahe_config, context_bytes)?;
let server_vahe = ShellVahe::new(vahe_config.clone(), context_bytes)?;
let verifier_vahe = ShellVahe::new(vahe_config, context_bytes)?;
let server = WillowV1Server { kahe: server_kahe, vahe: server_vahe };
let verifier = WillowV1Verifier { vahe: verifier_vahe };
Ok(Self {
Expand Down Expand Up @@ -659,10 +659,10 @@ impl FinalResultDecryptor {

// Build server that holds the necessary KAHE and AHE contexts, and recover server state.
let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?;
let context_string = aggregation_config.compute_context_bytes()?;
let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?;
let kahe = ShellKahe::new(kahe_config, &context_string)?;
let vahe = ShellVahe::new(vahe_config, &context_string)?;
let context_bytes = &aggregation_config.key_id;
let kahe = ShellKahe::new(kahe_config, context_bytes)?;
let vahe = ShellVahe::new(vahe_config, context_bytes)?;
let server = WillowV1Server { kahe, vahe };
let server_state = ServerState::from_proto(server_state_proto, &server)?;

Expand Down
6 changes: 3 additions & 3 deletions willow/src/api/server_accumulator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ AggregationConfigProto CreateValidConfig() {
(*config.mutable_vector_configs())["test_vector"] = vector_config;
config.set_max_number_of_decryptors(1);
config.set_max_number_of_clients(10);
config.set_session_id("test_session");
config.set_key_id("test_key");
return config;
}

Expand All @@ -67,7 +67,7 @@ TEST(BasicServerAccumulatorTest, ToSerializedStateHasCorrectConfig) {
ASSERT_TRUE(state.ParseFromString(*serialized_state_or));
// Check if the config matches. We serialize and deserialize to compare protos
// easily or check fields.
EXPECT_EQ(state.aggregation_config().session_id(), config.session_id());
EXPECT_EQ(state.aggregation_config().key_id(), config.key_id());
EXPECT_EQ(state.aggregation_config().max_number_of_clients(),
config.max_number_of_clients());
}
Expand Down Expand Up @@ -382,7 +382,7 @@ TEST_F(ServerAccumulatorTest, MergeFailsWithOverlappingRanges) {

TEST_F(ServerAccumulatorTest, MergeFailsWithConfigMismatch) {
AggregationConfigProto config2 = config_;
config2.set_session_id("other_session");
config2.set_key_id("other_key");
SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2,
ServerAccumulator::Create(config2));

Expand Down
18 changes: 9 additions & 9 deletions willow/src/shell/ahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,10 @@ impl AheBase for ShellAhe {

type Config = ShellAheConfig;

fn new(config: Self::Config, context_string: &[u8]) -> Result<Self, status::StatusError> {
fn new(config: Self::Config, context_bytes: &[u8]) -> Result<Self, status::StatusError> {
let num_coeffs = 1 << config.log_n;
let public_seed = single_thread_hkdf::compute_hkdf(
context_string,
context_bytes,
b"",
b"ShellAhe.public_seed",
single_thread_hkdf::seed_length(),
Expand Down Expand Up @@ -783,13 +783,13 @@ mod test {
const NUM_DECRYPTORS: usize = 3;
const NUM_CLIENTS: usize = 1000;
const MAX_ABSOLUTE_VALUE: i64 = 72;
const CONTEXT_STRING: &[u8] = b"test_context_string";
const CONTEXT_BYTES: &[u8] = b"test_context_bytes";

#[gtest]
fn test_encrypt_decrypt_one() -> googletest::Result<()> {
const NUM_VALUES: usize = 100;

let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;

let pt = vec![1, 2, 3, 4, 5, 6, 7, 8];
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -811,7 +811,7 @@ mod test {
fn test_encrypt_decrypt_serialized() -> googletest::Result<()> {
const NUM_VALUES: usize = 100;

let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;

let pt = vec![1, 2, 3, 4, 5, 6, 7, 8];
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand Down Expand Up @@ -853,7 +853,7 @@ mod test {
let config = make_ahe_config();
let t = config.t; // Keep a copy of the plaintext modulus.

let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -920,7 +920,7 @@ mod test {

#[gtest]
fn test_errors() -> googletest::Result<()> {
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_STRING)?;
let ahe = ShellAhe::new(make_ahe_config(), CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -998,7 +998,7 @@ mod test {
let config = make_ahe_config();
let q: i128 = config.qs.iter().map(|x| *x as i128).product();

let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;
Expand Down Expand Up @@ -1040,7 +1040,7 @@ mod test {
#[gtest]
fn test_export_ciphertext_has_right_order() -> googletest::Result<()> {
let config = make_ahe_config();
let ahe = ShellAhe::new(config, CONTEXT_STRING)?;
let ahe = ShellAhe::new(config, CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
let (_, pk_share, _) = ahe.key_gen(&mut prng)?;
Expand Down
22 changes: 11 additions & 11 deletions willow/src/shell/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ impl KaheBase for ShellKahe {

fn new(
shell_kahe_config: Self::Config,
context_string: &[u8],
context_bytes: &[u8],
) -> Result<Self, status::StatusError> {
Self::validate_kahe_config(&shell_kahe_config)?;
let num_coeffs = 1 << shell_kahe_config.log_n;
let public_seed = single_thread_hkdf::compute_hkdf(
context_string,
context_bytes,
b"",
b"ShellKahe.public_seed",
single_thread_hkdf::seed_length(),
Expand Down Expand Up @@ -395,7 +395,7 @@ mod test {
/// Default ID used in tests.
const DEFAULT_ID: &str = "default";

const CONTEXT_STRING: &[u8] = b"test_context_string";
const CONTEXT_BYTES: &[u8] = b"test_context_bytes";

#[gtest]
fn test_encrypt_decrypt_short() -> googletest::Result<()> {
Expand All @@ -405,7 +405,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -425,7 +425,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 8 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -445,7 +445,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand Down Expand Up @@ -484,7 +484,7 @@ mod test {
packed_vector_config.length = num_messages;
set_kahe_num_public_polynomials(&mut kahe_config);

let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;
Expand Down Expand Up @@ -518,7 +518,7 @@ mod test {
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;

let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -556,7 +556,7 @@ mod test {
let packed_vector_configs = BTreeMap::from([]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;

let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;
let seed = SingleThreadHkdfPrng::generate_seed()?;
let mut prng = SingleThreadHkdfPrng::create(&seed)?;

Expand Down Expand Up @@ -600,7 +600,7 @@ mod test {
PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 },
)]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]);
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand All @@ -626,7 +626,7 @@ mod test {
let plaintext_modulus_bits = 39;
let packed_vector_configs = BTreeMap::from([]);
let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?;
let kahe = ShellKahe::new(kahe_config, CONTEXT_BYTES)?;

// The seed used to sample the secret keys.
let seed = SingleThreadHkdfPrng::generate_seed()?;
Expand Down
10 changes: 5 additions & 5 deletions willow/src/shell/parameters_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod test {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 1,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};
let invalid_plaintext_bits = 0;
let result = generate_packing_config(invalid_plaintext_bits, &agg_config);
Expand All @@ -130,7 +130,7 @@ mod test {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 1,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};
let result = generate_packing_config(plaintext_bits, &bad_agg_config);
expect_true!(result.is_err());
Expand All @@ -151,7 +151,7 @@ mod test {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 0,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};
let result = generate_packing_config(plaintext_bits, &bad_agg_config);
expect_true!(result.is_err());
Expand All @@ -168,7 +168,7 @@ mod test {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 2,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};
let result = generate_packing_config(plaintext_bits, &agg_config);
expect_true!(result.is_err());
Expand All @@ -187,7 +187,7 @@ mod test {
max_number_of_decryptors: 1,
max_decryptor_dropouts: 0,
max_number_of_clients: 1 << 8,
session_id: String::from("test"),
key_id: b"test".to_vec(),
};
let plaintext_bits = 24;
let packed_vector_configs = generate_packing_config(plaintext_bits, &agg_config)?;
Expand Down
4 changes: 2 additions & 2 deletions willow/src/testing_utils/shell_testing_decryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace testing {
// encrypted messages can be decrypted properly.
class ShellTestingDecryptor {
public:
// Creates a new ShellTestingDecryptor from the given config, hashing the
// session ID from the config to seed KAHE and AHE public parameters.
// Creates a new ShellTestingDecryptor from the given config. The key_id from
// the config is used to seed KAHE and AHE public parameters.
static absl::StatusOr<std::unique_ptr<ShellTestingDecryptor>> Create(
const willow::AggregationConfigProto& aggregation_config);

Expand Down
Loading