diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index b3de917bf54..ff45ae51961 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -31,8 +31,8 @@ option(USE_OPEN_BLAS "Whether to use OpenBLAS (Defaults to compiling with Intel option(USE_INTEL_MKL "Whether to use Intel MKL (Defaults to compiling with Intel MKL)" OFF) # Build a shared libraray -set(HEADER_FILES libmatrixdnn.h libmatrixmult.h systemds.h common.h) -set(SOURCE_FILES libmatrixdnn.cpp libmatrixmult.cpp systemds.cpp) +set(HEADER_FILES libmatrixdnn.h libmatrixmult.h libhe.h systemds.h common.h) +set(SOURCE_FILES libmatrixdnn.cpp libmatrixmult.cpp libhe.cpp systemds.cpp) # Build a shared libraray add_library(systemds SHARED ${SOURCE_FILES} ${HEADER_FILES}) @@ -81,6 +81,9 @@ elseif(USE_INTEL_MKL) add_definitions(-DUSE_INTEL_MKL) endif() +find_package(SEAL 3.7 REQUIRED) +target_link_libraries(systemds SEAL::seal) + # Include directories. (added for Linux & Darwin, fix later for windows) # include paths can be spurious include_directories($ENV{JAVA_HOME}/include/) diff --git a/src/main/cpp/lib/libsystemds_mkl-Linux-x86_64.so b/src/main/cpp/lib/libsystemds_mkl-Linux-x86_64.so index a677b940138..ef98aa1542d 100644 Binary files a/src/main/cpp/lib/libsystemds_mkl-Linux-x86_64.so and b/src/main/cpp/lib/libsystemds_mkl-Linux-x86_64.so differ diff --git a/src/main/cpp/lib/libsystemds_openblas-Linux-x86_64.so b/src/main/cpp/lib/libsystemds_openblas-Linux-x86_64.so index 227443e3248..79ac01b7e66 100644 Binary files a/src/main/cpp/lib/libsystemds_openblas-Linux-x86_64.so and b/src/main/cpp/lib/libsystemds_openblas-Linux-x86_64.so differ diff --git a/src/main/cpp/libhe.cpp b/src/main/cpp/libhe.cpp new file mode 100644 index 00000000000..9bf16a87a12 --- /dev/null +++ b/src/main/cpp/libhe.cpp @@ -0,0 +1,279 @@ +// +// Created by me on 1/18/22. +// + +#include +#include +#include +#include + +#include "libhe.h" + +#include "seal/seal.h" +#include "seal/util/common.h" +#include "seal/util/rlwe.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal; + +RawPolynomData::RawPolynomData(const SEALContext& context) { + // Extract encryption parameters + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + auto coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + _size = util::mul_safe(coeff_count, coeff_modulus_size); +}; + +void RawPolynomData::set_data(vector& data) { + assert(data.size() == _size); + _data = move(data); +}; + + +gsl::span data_span(Ciphertext& c, size_t n) { + size_t poly_size = util::mul_safe(c.poly_modulus_degree(), c.coeff_modulus_size()); + return { c.data(n), poly_size }; +} + +RawPolynomData generate_a(const SEALContext& context) { + auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create(); + + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + + RawPolynomData rpd(parms); + vector a_poly_data(rpd.size()); + util::sample_poly_uniform(ciphertext_prng, parms, a_poly_data.data()); + rpd.set_data(a_poly_data); + return rpd; +} + +EncryptionParameters generateParameters() { + EncryptionParameters parms(scheme_type::ckks); + + size_t poly_modulus_degree = 8192; + parms.set_poly_modulus_degree(poly_modulus_degree); + parms.set_coeff_modulus(CoeffModulus::Create(poly_modulus_degree, { 60, 40, 40, 60 })); + return parms; +} + +size_t get_slot_count(const SEALContext& ctx) { + // slot count is only half of it. but every slot can take one complex number or 2 doubles. so in the end we get twice + // the space + return ctx.first_context_data()->parms().poly_modulus_degree(); +} + +// returns a vector filled with random double values between 0 and 1 +vector random_plaintext_data(size_t count) { + // this example is just copied from the CKKS example of SEAL + vector data; + data.reserve(count); + for (size_t i = 0; i < count; i++) + { + data.push_back(sqrt(static_cast(rand()) / RAND_MAX)); + } + return data; +} + +GlobalState::GlobalState(double _scale) : context(generateParameters()), a(generate_a(context)), scale(_scale) {}; + + +PublicKey Client::generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a) +{ + PublicKey public_key; + Ciphertext& destination = public_key.data(); + + // We use a fresh memory pool with `clear_on_destruction' enabled. + MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true); + + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + auto ntt_tables = context_data.small_ntt_tables(); + size_t encrypted_size = 2; + + // If a polynomial is too small to store UniformRandomGeneratorInfo, + // it is best to just disable save_seed. Note that the size needed is + // the size of UniformRandomGeneratorInfo plus one (uint64_t) because + // of an indicator word that indicates a seeded ciphertext. + size_t poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); + + destination.resize(context, context.key_parms_id(), encrypted_size); + destination.is_ntt_form() = true; + destination.scale() = 1.0; + + // Create an instance of a random number generator. We use this for sampling + // a seed for a second PRNG used for sampling u (the seed can be public + // information. This PRNG is also used for sampling the noise/error below. + auto bootstrap_prng = parms.random_generator()->create(); + + // Sample a public seed for generating uniform randomness + prng_seed_type public_prng_seed; + bootstrap_prng->generate(prng_seed_byte_count, reinterpret_cast(public_prng_seed.data())); + + // Set up a new default PRNG for expanding u from the seed sampled above + auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create(public_prng_seed); + + // Generate ciphertext: (c[0], c[1]) = ([-(as+e)]_q, a) + uint64_t *c0 = destination.data(); + uint64_t *c1 = destination.data(1); + + // copy a into c1 + assert(a.size() == poly_uint64_count); + copy(a.data(), a.data()+poly_uint64_count, c1); + + // Sample e <-- chi + auto noise(util::allocate_poly(coeff_count, coeff_modulus_size, pool)); + util::SEAL_NOISE_SAMPLER(bootstrap_prng, parms, noise.get()); + + // Calculate -(a*s + e) (mod q) and store in c[0] + for (size_t i = 0; i < coeff_modulus_size; i++) + { + util::dyadic_product_coeffmod( + secret_key.data().data() + i * coeff_count, c1 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + + // Transform the noise e into NTT representation + ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]); + + util::add_poly_coeffmod( + noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + util::negate_poly_coeffmod(c0 + i * coeff_count, coeff_count, coeff_modulus[i], c0 + i * coeff_count); + } + + public_key.parms_id() = context.key_parms_id(); + return public_key; +} + +Client::Client(GlobalState global_state) : _gs(move(global_state)), _encoder(_gs.context) { + KeyGenerator keygen(_gs.context); + _partial_secret_key = keygen.secret_key(); + _partial_public_key = generate_partial_public_key(_partial_secret_key, _gs.context, _gs.a); +}; + +Ciphertext Client::encrypted_data(gsl::span plain_data) { + if (!_encryptor) { + _encryptor = make_unique(_gs.context, *_public_key); + } + + // reinterpret plain data as complex + assert(plain_data.size() % 2 == 0); + gsl::span complex_plain_data(reinterpret_cast*>(plain_data.data()), plain_data.size() / 2); + + Plaintext plaintext; + encoder().encode(complex_plain_data, _gs.scale, plaintext); + Ciphertext ciphertext; + encryptor().encrypt(plaintext, ciphertext); + return ciphertext; +} + +Plaintext Client::partial_decryption(const Ciphertext& encrypted) { + using namespace seal::util; + + // c = (c0, c1) + // dec(c) = c0+c1*s + // we need: c0 + c1*sum(s[i]) + // so we return c1*s[i]*e[i] and add c0 at the server. e[i] is a noise term necessary for security + + // adapted from Decryptor::decrypt + + auto &context_data = *_gs.context.get_context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_modulus_size); + + Plaintext plaintext; + // Since we overwrite destination, we zeroize destination parameters + // This is necessary, otherwise resize will throw an exception. + plaintext.parms_id() = parms_id_zero; + + // Resize destination to appropriate size + plaintext.resize(rns_poly_uint64_count); + + // Do the dot product of encrypted and the secret key array using NTT. + RNSIter destination(plaintext.data(), coeff_count); + ConstRNSIter secret_key_array(_partial_secret_key.data().data(), coeff_count); + ConstRNSIter c1(encrypted.data(1), coeff_count); + + SEAL_ITERATE( + iter(c1, secret_key_array, coeff_modulus, destination), coeff_modulus_size, [&](auto I) { + // put < c_1 * s > mod q in destination + dyadic_product_coeffmod(get<0>(I), get<1>(I), coeff_count, get<2>(I), get<3>(I)); + }); + + // for security we need to introduce noise here + // this part is based on rlwe.cpp:encrypt_zero_symmetric() + auto prng = parms.random_generator()->create(); + MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true); + auto noise(allocate_poly(coeff_count, coeff_modulus_size, pool)); + SEAL_NOISE_SAMPLER(prng, parms, noise.get()); + auto ntt_tables = context_data.small_ntt_tables(); + + for (size_t i = 0; i < coeff_modulus_size; i++) + { + // Transform the noise e into NTT representation + ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]); + + add_poly_coeffmod( + noise.get() + i * coeff_count, plaintext.data() + i * coeff_count, coeff_count, coeff_modulus[i], + plaintext.data() + i * coeff_count); + } + + // Set destination parameters as in encrypted + plaintext.parms_id() = encrypted.parms_id(); + plaintext.scale() = encrypted.scale(); + return plaintext; +} + +Server::Server(GlobalState global_state) : _gs(move(global_state)) {}; + +void Server::accumulate_partial_public_keys(gsl::span partial_pub_keys) { + // sum only the first poly of the ciphertexts + // the second poly is always the same, see GlobalState.a + Ciphertext sum = sum_first_polys(context(), partial_pub_keys); + _public_key.data() = sum; + assert(is_valid_for(_public_key, context())); +} + +Ciphertext Server::sum_data(vector&& data) const { + Evaluator e(_gs.context); + Ciphertext result; + e.add_many(data, result); + return result; +} + +vector Server::average(const Ciphertext& encrypted_sum, gsl::span partial_decryptions) const { + // the partial decryptions were of the form c1*s[i]. we need c0 + sum(c1+s[i]) + // so we need to add c0 once here. + + // FIXME: this copies encrypted_sum, which is unnecessary + uint64_t num_coeffs = util::mul_safe(encrypted_sum.poly_modulus_degree(), encrypted_sum.coeff_modulus_size()); + gsl::span es_data(encrypted_sum.data(0), num_coeffs); + Plaintext c0(es_data); + c0.parms_id() = context().first_parms_id(); + c0.scale() = encrypted_sum.scale(); + + sum_first_polys_inplace(_gs.context, c0, partial_decryptions); // c0 + sum(c1+s[i]) + + // decode sum + size_t slot_count = context().first_context_data()->parms().poly_modulus_degree() >> 1; + CKKSEncoder encoder(context()); + vector result(slot_count * 2, 0.0); + gsl::span> result_destination(reinterpret_cast*>(result.data()), slot_count); + encoder.decode(c0, result_destination); + + // divide by N for average + for (double& x : result) { + x /= static_cast(partial_decryptions.size()); + } + return result; +} + diff --git a/src/main/cpp/libhe.h b/src/main/cpp/libhe.h new file mode 100644 index 00000000000..efea1d02131 --- /dev/null +++ b/src/main/cpp/libhe.h @@ -0,0 +1,129 @@ +// +// Created by me on 1/18/22. +// + +#ifndef LIBHE_H +#define LIBHE_H + +#include +#include +#include +#include + +#include "seal/seal.h" +#include "seal/util/common.h" +#include "seal/util/rlwe.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal; + +class RawPolynomData { + vector _data; + size_t _size; + +public: + explicit RawPolynomData(const SEALContext& context); + + SEAL_NODISCARD inline const size_t& size() const { return _size; }; + SEAL_NODISCARD inline uint64_t* data() { return _data.data(); }; + SEAL_NODISCARD inline gsl::span data_span() { return { data(), size() }; }; + + void set_data(vector& data); +}; + +gsl::span data_span(Ciphertext& c, size_t n); + +RawPolynomData generate_a(const SEALContext& context); + +EncryptionParameters generateParameters(); + +size_t get_slot_count(const SEALContext& ctx); + +// returns a vector filled with random double values between 0 and 1 +vector random_plaintext_data(size_t count); + +struct GlobalState { + SEALContext context; + RawPolynomData a; + double scale; + + explicit GlobalState(double _scale); +}; + +class Client { + GlobalState _gs; + CKKSEncoder _encoder; + SecretKey _partial_secret_key; + PublicKey _partial_public_key; + std::optional _public_key = std::nullopt; + std::unique_ptr _encryptor = nullptr; + + SEAL_NODISCARD static PublicKey generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a); + +public: + explicit Client(GlobalState global_state); + + SEAL_NODISCARD inline const SEALContext& context() const { return _gs.context; }; + SEAL_NODISCARD inline const PublicKey& partial_public_key() const { return _partial_public_key; }; + SEAL_NODISCARD inline const CKKSEncoder& encoder() const { return _encoder; }; + SEAL_NODISCARD inline CKKSEncoder& encoder() { return _encoder; }; + SEAL_NODISCARD inline const Encryptor& encryptor() const { assert(_encryptor != nullptr); return *_encryptor; }; + SEAL_NODISCARD inline const PublicKey& public_key() { return *_public_key; }; + inline void set_public_key(const PublicKey& pk) { _public_key = make_optional(pk); }; + + Ciphertext encrypted_data(gsl::span plain_data); + + Plaintext partial_decryption(const Ciphertext& encrypted); +}; + +// adds b to a in place +template void sum_first_poly_inplace(const SEALContext& context, T& a, const T& b) { + auto &context_data = *context.get_context_data(a.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + + // by dereferencing we get only the first poly + auto summand_iter = *util::ConstPolyIter(b.data(), coeff_count, coeff_modulus_size); + auto sum_iter = *util::ConstPolyIter(a.data(), coeff_count, coeff_modulus_size); + auto result_iter = *util::PolyIter(a.data(), coeff_count, coeff_modulus_size); + // see Evaluator::add_inplace + util::add_poly_coeffmod(sum_iter, summand_iter, coeff_modulus_size, coeff_modulus, result_iter); +} + +// This function adds the first polys in summands to sum (either Ciphertext or Plaintext). +template T sum_first_polys_inplace(const SEALContext& context, T& sum, gsl::span summands) { + for (size_t i = 0; i < summands.size(); i++) { + sum_first_poly_inplace(context, sum, summands[i]); + } + return sum; +} + +// This function sums the first polys in summands (either Ciphertext or Plaintext). +template T sum_first_polys(const SEALContext& context, gsl::span summands) { + T sum = summands[0]; + sum_first_polys_inplace(context, sum, gsl::span(&summands.data()[1], summands.size() - 1)); + return sum; +} + +class Server { + GlobalState _gs; + PublicKey _public_key; + +public: + explicit Server(GlobalState global_state); + + SEAL_NODISCARD inline RawPolynomData& a() { return _gs.a; }; + SEAL_NODISCARD inline const SEALContext& context() const { return _gs.context; }; + SEAL_NODISCARD inline const PublicKey& public_key() const { return _public_key; }; + + void accumulate_partial_public_keys(gsl::span partial_pub_keys); + + Ciphertext sum_data(vector&& data) const; + + vector average(const Ciphertext& encrypted_sum, gsl::span partial_decryptions) const; +}; + +#endif //LIBHE_H diff --git a/src/main/cpp/systemds.cpp b/src/main/cpp/systemds.cpp index bed1d42f578..c957022d7b0 100644 --- a/src/main/cpp/systemds.cpp +++ b/src/main/cpp/systemds.cpp @@ -17,9 +17,16 @@ * under the License. */ +#ifdef _WIN32 +#include +#else +#include +#endif + #include "common.h" #include "libmatrixdnn.h" #include "libmatrixmult.h" +#include "libhe.h" #include "systemds.h" // Results from Matrix-vector/vector-matrix 1M x 1K, dense show that GetDoubleArrayElements creates a copy on OpenJDK. @@ -249,3 +256,270 @@ JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dBackwardF RELEASE_ARRAY(env, ret, retPtr, numThreads); return static_cast(nnz); } + +unique_ptr get_stream(JNIEnv* env, jbyteArray ary) { + size_t size = env->GetArrayLength(ary); + jbyte* data = env->GetByteArrayElements(ary, NULL); + + // FIXME: this copies string data TWICE. maybe implement a custom stream + // idea: implement a custom stream that wraps a jbyteArray, which calls ReleaseByteArrayElements in its d'tor + string data_s = string(reinterpret_cast(data), size); + unique_ptr ret = std::make_unique(data_s); + env->ReleaseByteArrayElements(ary, data, JNI_ABORT); + return ret; +} + +jbyteArray allocate_byte_array(JNIEnv* env, ostringstream& stream) { + string data = stream.str(); // FIXME: this copies string content. maybe implement custom ostream + jbyteArray ret = env->NewByteArray(data.size()); + env->SetByteArrayRegion(ret, 0, data.size(), reinterpret_cast(data.data())); + return ret; +} + +void my_assert(bool assertion, const char* message = "Assertion failed") { + if (!assertion) { + throw logic_error(message); + } +} + +template jbyteArray serialize(JNIEnv* env, T& object) { + ostringstream ss; + object.save(ss); + return allocate_byte_array(env, ss); +} + +void serialize_uint32_t(ostream& ss, uint32_t n) { + n = htonl(n); + ss.write(reinterpret_cast(&n), sizeof(n)); +} + +uint32_t deserialize_uint32_t(istream& ss) { + uint32_t ret; + ss.read(reinterpret_cast(&ret), sizeof(ret)); + ret = ntohl(ret); + return ret; +} + +Ciphertext deserialize_ciphertext(istream& ss, const SEALContext& context) { + Ciphertext ret; + ret.load(context, ss); + return ret; +} + +void serialize_plaintext(ostream& ss, Plaintext plaintext) { + plaintext.save(ss); +} + +template T deserialize_unsafe(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) { + auto ss = get_stream(env, serialized_object); + T deserialized; + deserialized.unsafe_load(context, *ss); // necessary bc partial public keys are not valid public keys + return deserialized; +} + +template T deserialize(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) { + auto ss = get_stream(env, serialized_object); + T deserialized; + deserialized.load(context, *ss); // necessary bc partial public keys are not valid public keys + return deserialized; +} + +JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_initClient + (JNIEnv* env, jclass, jbyteArray a_ary) { + double scale = pow(2.0, 40); + GlobalState gs(scale); + + // copy a to global state + size_t byte_size = env->GetArrayLength(a_ary); + my_assert(byte_size % sizeof(uint64_t) == 0); + size_t size = byte_size / sizeof(uint64_t); + uint64_t* a = reinterpret_cast(env->GetByteArrayElements(a_ary, NULL)); + gsl::span new_a(a, size); + + vector new_a_buf; + new_a_buf.assign(new_a.begin(), new_a.end()); + gs.a.set_data(new_a_buf); + + // release a without back-copy + env->ReleaseByteArrayElements(a_ary, reinterpret_cast(a), JNI_ABORT); + + Client* client = new Client(gs); + return reinterpret_cast(client); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_generatePartialPublicKey + (JNIEnv* env, jclass, jlong client_ptr) { + Client* client = reinterpret_cast(client_ptr); + return serialize(env, client->partial_public_key().data()); +} + + +JNIEXPORT void JNICALL Java_org_apache_sysds_utils_NativeHelper_setPublicKey + (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_public_key) { + Client* client = reinterpret_cast(client_ptr); + client->set_public_key(deserialize(env, client->context(), serialized_public_key)); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_encrypt + (JNIEnv* env, jclass, jlong client_ptr, jdoubleArray jdata) { + Client* client = reinterpret_cast(client_ptr); + size_t slot_count = get_slot_count(client->context()); + size_t num_data = env->GetArrayLength(jdata); + const double* data = static_cast(env->GetDoubleArrayElements(jdata, NULL)); + + std::ostringstream ss; + // write chunk size + uint32_t num_chunks = (num_data - 1) / slot_count + 1; + serialize_uint32_t(ss, num_chunks); + for (size_t i = 0; i < num_chunks; i++) { + size_t offset = slot_count * i; + size_t length = min(slot_count, num_data-offset); + gsl::span data_span(&data[offset], length); + Ciphertext encrypted_chunk = client->encrypted_data(data_span); + encrypted_chunk.save(ss); + } + return allocate_byte_array(env, ss); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_partiallyDecrypt + (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_ciphertexts) { + Client* client = reinterpret_cast(client_ptr); + auto input = get_stream(env, serialized_ciphertexts); + std::ostringstream ss; + + // read num of chunks + uint32_t num_chunks = deserialize_uint32_t(*input); + + // write chunk size + serialize_uint32_t(ss, num_chunks); + for (int i = 0; i < num_chunks; i++) { + Ciphertext ciphertext = deserialize_ciphertext(*input, client->context()); + Plaintext plaintext = client->partial_decryption(ciphertext); + plaintext.save(ss); + } + + return allocate_byte_array(env, ss); +} + + +JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_initServer + (JNIEnv *, jclass) { + double scale = pow(2.0, 40); + GlobalState gs(scale); + Server* server = new Server(gs); + return reinterpret_cast(server); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_generateA + (JNIEnv* env, jclass, jlong server_ptr) { + Server* server = reinterpret_cast(server_ptr); + uint64_t* data = server->a().data(); + size_t size = server->a().size() * sizeof(data[0]) / sizeof(jbyte); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, reinterpret_cast(data)); + return ret; +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_aggregatePartialPublicKeys + (JNIEnv* env, jclass, jlong server_ptr, jobjectArray partial_public_keys_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t num_partial_public_keys = env->GetArrayLength(partial_public_keys_serialized); + std::vector partial_public_keys; + partial_public_keys.reserve(num_partial_public_keys); + + for (int i = 0; i < num_partial_public_keys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(partial_public_keys_serialized, i)); + partial_public_keys.push_back(deserialize_unsafe(env, server->context(), j_data)); + } + + server->accumulate_partial_public_keys(gsl::span(partial_public_keys)); + return serialize(env, server->public_key()); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_accumulateCiphertexts + (JNIEnv* env, jclass, jlong server_ptr, jobjectArray ciphertexts_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t num_ciphertext_arys = env->GetArrayLength(ciphertexts_serialized); + + // init streams + vector> buf; + buf.reserve(num_ciphertext_arys); + for (int i = 0; i < num_ciphertext_arys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(ciphertexts_serialized, i)); + auto stream = get_stream(env, j_data); + buf.emplace_back(std::move(stream)); + } + + // read lengths of ciphertext arys and check that they are all the same + uint32_t num_slots = deserialize_uint32_t(*buf[0]); + for (int i = 1; i < num_ciphertext_arys; i++) { + my_assert(deserialize_uint32_t(*buf[i]) == num_slots); + } + + // read ciphertexts in chunks and accumulate them + ostringstream result; + serialize_uint32_t(result, num_slots); + for (int chunk_idx = 0; chunk_idx < num_slots; chunk_idx++) { + vector ciphertexts; + ciphertexts.reserve(num_ciphertext_arys); + for (int i = 0; i < num_ciphertext_arys; i++) { + Ciphertext deserialized; + deserialized.load(server->context(), *buf[i]); + ciphertexts.emplace_back(deserialized); + } + Ciphertext sum = server->sum_data(std::move(ciphertexts)); + sum.save(result); + } + + return allocate_byte_array(env, result); +} + + +JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_utils_NativeHelper_average + (JNIEnv* env, jclass, jlong server_ptr, jbyteArray ciphertext_sum_serialized, jobjectArray partial_decryptions_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t slot_size = get_slot_count(server->context()); + size_t num_plaintext_arys = env->GetArrayLength(partial_decryptions_serialized); + + // init streams + vector> buf; + buf.reserve(num_plaintext_arys); + for (int i = 0; i < num_plaintext_arys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(partial_decryptions_serialized, i)); + auto stream = get_stream(env, j_data); + buf.emplace_back(std::move(stream)); + } + + // read lengths of ciphertext arys and check that they are all the same + uint32_t num_slots = deserialize_uint32_t(*buf[0]); + for (int i = 1; i < num_plaintext_arys; i++) { + my_assert(deserialize_uint32_t(*buf[i]) == num_slots, "number of plaintext slots is different"); + } + + auto encrypted_sum_stream = get_stream(env, ciphertext_sum_serialized); + my_assert(deserialize_uint32_t(*encrypted_sum_stream) == num_slots, "number of ciphertext slots is different"); + + // read ciphertexts in chunks and accumulate them + jdoubleArray result = env->NewDoubleArray(num_slots * slot_size); + for (int chunk_idx = 0; chunk_idx < num_slots; chunk_idx++) { + Ciphertext encrypted_sum = deserialize_ciphertext(*encrypted_sum_stream, server->context()); + + vector partial_decryptions; + partial_decryptions.reserve(num_plaintext_arys); + for (int i = 0; i < num_plaintext_arys; i++) { + Plaintext deserialized; + deserialized.load(server->context(), *buf[i]); + partial_decryptions.emplace_back(deserialized); + } + vector<double> averages = server->average(encrypted_sum, move(partial_decryptions)); + env->SetDoubleArrayRegion(result, chunk_idx*slot_size, averages.size(), averages.data()); + } + + return result; +} \ No newline at end of file diff --git a/src/main/cpp/systemds.h b/src/main/cpp/systemds.h index b4f741f2904..c268a406e1c 100644 --- a/src/main/cpp/systemds.h +++ b/src/main/cpp/systemds.h @@ -114,6 +114,86 @@ JNIEXPORT jboolean JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dSparse JNIEXPORT void JNICALL Java_org_apache_sysds_utils_NativeHelper_setMaxNumThreads (JNIEnv *, jclass, jint); +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: initClient + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_initClient + (JNIEnv *, jclass, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: generatePartialPublicKey + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_generatePartialPublicKey + (JNIEnv *, jclass, jlong); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: setPublicKey + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_apache_sysds_utils_NativeHelper_setPublicKey + (JNIEnv *, jclass, jlong, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: encrypt + * Signature: (J[D)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_encrypt + (JNIEnv *, jclass, jlong, jdoubleArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: partiallyDecrypt + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_partiallyDecrypt + (JNIEnv *, jclass, jlong, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: initServer + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_initServer + (JNIEnv *, jclass); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: generateA + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_generateA + (JNIEnv *, jclass, jlong); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: aggregatePartialPublicKeys + * Signature: (J[[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_aggregatePartialPublicKeys + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: accumulateCiphertexts + * Signature: (J[[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_utils_NativeHelper_accumulateCiphertexts + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: average + * Signature: (J[B[[B)[D + */ +JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_utils_NativeHelper_average + (JNIEnv *, jclass, jlong, jbyteArray, jobjectArray); + #ifdef __cplusplus } #endif diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index b6ef98abc72..3dfad3413e5 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -44,7 +44,7 @@ public enum ExecType { CP, CP_FILE, SPARK, GPU, FED, INVALID } * Data types (tensor, matrix, scalar, frame, object, unknown). */ public enum DataType { - TENSOR, MATRIX, SCALAR, FRAME, LIST, UNKNOWN; + TENSOR, MATRIX, SCALAR, FRAME, LIST, ENCRYPTED_CIPHER, ENCRYPTED_PLAIN, UNKNOWN; public boolean isMatrix() { return this == MATRIX; diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 995a1e23309..dc9f5bba767 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -124,7 +124,6 @@ public enum PSCheckpointing { public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname"; public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid"; - public abstract boolean controlStatement(); public abstract VariableSet variablesRead(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index e398abd32b0..3a1c78cb1d1 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -38,6 +38,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -84,7 +85,9 @@ public class ExecutionContext { //parfor temporary functions (created by eval) protected Set<String> _fnNames; - + + protected SEALClient _seal_client; + /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */ @@ -152,6 +155,14 @@ public long getTID() { return _tid; } + public void setSealClient(SEALClient seal_client) { + _seal_client = seal_client; + } + + public SEALClient getSealClient() { + return _seal_client; + } + /** * Get the i-th GPUContext * @param index index of the GPUContext diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 004e35b5718..7dce506fbed 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -40,18 +40,13 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.cp.BooleanObject; -import org.apache.sysds.runtime.instructions.cp.CPOperand; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.instructions.cp.DoubleObject; -import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; -import org.apache.sysds.runtime.instructions.cp.IntObject; -import org.apache.sysds.runtime.instructions.cp.ListObject; -import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.instructions.cp.*; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.lineage.LineageItem; @@ -59,6 +54,7 @@ import org.apache.sysds.utils.stats.ParamServStatistics; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.concurrent.Callable; @@ -83,20 +79,23 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void> private final boolean _weighting; private double _weightingFactor = 1; private boolean _cycleStartAt0 = false; + private boolean _use_homomorphic_encryption = false; + private PublicKey _partial_public_key; public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize, - int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg) + int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption) { super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg); _numBatchesPerEpoch = numBatchesPerGlobalEpoch; _runtimeBalancing = runtimeBalancing; - _weighting = weighting; + _weighting = weighting && (!use_homomorphic_encryption); // FIXME: this disables weighting in favor of homomorphic encryption _numBatchesPerNbatch = nbatches; // generate the ID for the model _modelVarID = FederationUtils.getNextFedDataID(); - _modelAvg = modelAvg; + _modelAvg = _use_homomorphic_encryption || modelAvg; // we always have to use modelAvg when using homomorphic encryption + _use_homomorphic_encryption = use_homomorphic_encryption; } /** @@ -106,6 +105,9 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque */ public void setup(double weightingFactor) { incWorkerNumber(); + if (_use_homomorphic_encryption) { + ((HEParamServer)_ps).registerThread(_workerID, this); + } // prepare features and labels _featuresData = _features.getFedMapping().getFederatedData()[0]; @@ -160,21 +162,45 @@ public void setup(double weightingFactor) { PROG_END); // write program and meta data to worker - Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), - new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, - programSerialized, _inst.getNamespace(), _inst.getFunctionName(), - _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), - _modelVarID, _nbatches, _modelAvg))); + Future<FederatedResponse> udfResponse; + //_use_homomorphic_encryption = false; + if (_use_homomorphic_encryption) { + // TODO: generate a here + byte[] a = ((HEParamServer)_ps).generateA(); + // generate pk[i] on each client and return it + udfResponse = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), + new SetupHEFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, + programSerialized, _inst.getNamespace(), _inst.getFunctionName(), + _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), + _modelVarID, _nbatches, a))); + } else { + udfResponse = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), + new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, + programSerialized, _inst.getNamespace(), _inst.getFunctionName(), + _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), + _modelVarID, _nbatches, _modelAvg))); + } + FederatedResponse response; try { - FederatedResponse response = udfResponse.get(); + response = udfResponse.get(); if(!response.isSuccessful()) throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed"); + } catch(Exception e) { throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF" + e.getMessage()); } + if (_use_homomorphic_encryption) { + try { + _partial_public_key = (PublicKey) response.getData()[0]; + } + catch (Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: HE Setup UDF didn't return an object"); + } + } } /** @@ -196,7 +222,7 @@ public void teardown() { throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage()); } } - + /** * Setup UDF executed on the federated worker */ @@ -258,9 +284,61 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { } } - /** + private static class SetupHEFederatedWorker extends SetupFederatedWorker { + private static final long serialVersionUID = 9128347291804980123L; + + byte[] _partial_pubkey_a; + + protected SetupHEFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, + String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, + ListObject hyperParams, long modelVarID, int nbatches, byte[] partial_pubkey_a) { + // delegate everything to parent class. set modelAvg to true, as it is the only supported case + super(batchSize, dataSize, possibleBatchesPerLocalEpoch, programString, namespace, gradientsFunctionName, + aggregationFunctionName, hyperParams, modelVarID, nbatches, true); + + _partial_pubkey_a = partial_pubkey_a; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + // TODO: set other CKKS parameters + // TODO generate partial public key + SEALClient sc = new SEALClient(_partial_pubkey_a); + ec.setSealClient(sc); + PublicKey partial_pubkey = sc.generatePartialPublicKey(); + + FederatedResponse res = super.execute(ec, data); + if (!res.isSuccessful()) { + return res; + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, partial_pubkey); + } + } + /** * Teardown UDF executed on the federated worker */ + private static class SetPublicKeyFederatedWorker extends FederatedUDF { + private static final long serialVersionUID = -1536502123123318969L; + private final PublicKey _public_key; + + protected SetPublicKeyFederatedWorker(PublicKey public_key) { + super(new long[]{}); + _public_key = public_key; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + ec.getSealClient().setPublicKey(_public_key); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + return null; + } + } + private static class TeardownFederatedWorker extends FederatedUDF { private static final long serialVersionUID = -153650281873318969L; @@ -324,6 +402,7 @@ protected ListObject pullModel() { } protected void weightAndPushGradients(ListObject gradients) { + assert (!(_weighting && _use_homomorphic_encryption)) : "weights and homomorphic encryption are not supported together"; // scale gradients - must only include MatrixObjects if(_weighting && _weightingFactor != 1) { Timing tWeighting = DMLScript.STATISTICS ? new Timing(true) : null; @@ -354,11 +433,12 @@ protected void computeWithBatchUpdates() { int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); ListObject model = pullModel(); ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum); - if (_modelAvg) + if (_modelAvg && !_use_homomorphic_encryption) + // we can't call the agg fn if we use HE, because it is implemented homomorphically in SEALServer::aggregateCiphertexts model = _ps.updateLocalModel(_ec, gradients, model); else ParamservUtils.cleanupListObject(model); - weightAndPushGradients(_modelAvg ? model : gradients); + weightAndPushGradients((_modelAvg && !_use_homomorphic_encryption) ? model : gradients); ParamservUtils.cleanupListObject(gradients); } } @@ -431,11 +511,16 @@ protected ListObject computeGradientsForNBatches(ListObject model, } // create and execute the udf on the remote worker + Object udf; + if (_use_homomorphic_encryption) { + udf = new HEComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID()}, + new long[]{_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum); + } else { + udf = new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), + _modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum); + } Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), - new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), - _modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum) - )); + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf)); try { Object[] responseData = udfResponse.get().getData(); @@ -492,12 +577,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunc, - opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function"); + opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function"); DataIdentifier gradientsOutput = outputs.get(0); // recreate aggregation instruction and output if needed @@ -508,12 +593,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { inputs = func.getInputParams(); outputs = func.getOutputParams(); boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); aggregationInstruction = new FunctionCallCPInstruction(namespace, aggFunc, - opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function"); + opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function"); aggregationOutput = outputs.get(0); } ListObject accGradients = null; @@ -540,7 +625,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // accrue the computed gradients - In the single batch case this is just a list copy // is this equivalent for momentum based and AMS prob? accGradients = modelAvg ? null : - ParamservUtils.accrueGradients(accGradients, gradients, false); + ParamservUtils.accrueGradients(accGradients, gradients, false); // update the local model with gradients if needed // FIXME ensure that with modelAvg we always update the model @@ -564,11 +649,82 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // model clean up ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString()); // TODO double check cleanup gradients and models - + // stop timing DoubleObject gradientsTime = new DoubleObject(tGradients.stop()); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, - new Object[]{modelAvg ? model : accGradients, gradientsTime}); + new Object[]{modelAvg ? model : accGradients, gradientsTime}); + } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + return null; + } + } + + + /** + * This wraps federatedComputeGradientsForNBatches and adds encryption + */ + private static class HEComputeGradientsForNBatches extends federatedComputeGradientsForNBatches { + private static final long serialVersionUID = -3535901512348794852L; + private final long[] _deferredIds; + + protected HEComputeGradientsForNBatches(long[] deferredIds, long[] inIDs, int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) { + super(inIDs, numBatchesToCompute, localUpdate, localStartBatchNum); + _deferredIds = deferredIds; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data_without_deferred) { + Timing tGradients = new Timing(true); + // add features and gradients to data + Data[] deferred_inputs = Arrays.stream(_deferredIds).mapToObj(id -> ec.getVariable(String.valueOf(id))).toArray(Data[]::new); + Data[] data = Arrays.copyOf(deferred_inputs, deferred_inputs.length + data_without_deferred.length); + System.arraycopy(data_without_deferred, 0, data, deferred_inputs.length, data_without_deferred.length); + FederatedResponse res = super.execute(ec, data); + + if (!res.isSuccessful()) { + return res; + } + + // encrypt model with SEAL + try { + ListObject model = (ListObject) res.getData()[0]; + ListObject encrypted_model = new ListObject(model); + for (int matrix_idx = 0; matrix_idx < model.getLength(); matrix_idx++) { + CiphertextMatrix encrypted_matrix = ec.getSealClient().encrypt((MatrixObject) model.getData(matrix_idx)); + encrypted_model.set(matrix_idx, encrypted_matrix); + } + // overwrite model with encryption + res.getData()[0] = encrypted_model; + + // stop timing + DoubleObject gradientsTime = new DoubleObject(tGradients.stop()); + res.getData()[1] = gradientsTime; + } catch (Exception e) { + return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new Object[] { e }); + } + return res; + } + } + + private static class HEComputePartialDecryption extends FederatedUDF { + private static final long serialVersionUID = -4535098129348794852L; + private final CiphertextMatrix[] _encrypted_sum; + + protected HEComputePartialDecryption(CiphertextMatrix[] encrypted_sum) { + super(new long[]{}); + _encrypted_sum = encrypted_sum; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + PlaintextMatrix[] result = new PlaintextMatrix[_encrypted_sum.length]; + for (int i = 0; i < result.length; i++) { + result[i] = ec.getSealClient().partiallyDecrypt(_encrypted_sum[i]); + } + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, result); } @Override @@ -577,6 +733,20 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { } } + + public PlaintextMatrix[] getPartialDecryption(CiphertextMatrix[] encrypted_sum) { + Object udf = new HEComputePartialDecryption(encrypted_sum); + Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf)); + + try { + Object[] responseData = udfResponse.get().getData(); + return (PlaintextMatrix[]) responseData; + } catch(Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage()); + } + } + // Statistics methods protected void accFedPSGradientWeightingTime(Timing time) { if (DMLScript.STATISTICS && time != null) @@ -608,4 +778,24 @@ protected void accBatchIndexingTime(Timing time) { protected void accGradientComputeTime(Timing time) { throw new NotImplementedException(); } + + public PublicKey getPartialPublicKey() { + return _partial_public_key; + } + + public void setPublicKey(PublicKey public_key) { + Future<FederatedResponse> res = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), + new SetPublicKeyFederatedWorker(public_key))); + + try { + FederatedResponse response = res.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedLocalPSThread: SetPublicKey UDF failed"); + + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Public Key Setup UDF" + e.getMessage()); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java new file mode 100644 index 00000000000..145b1e932d5 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java @@ -0,0 +1,136 @@ +package org.apache.sysds.runtime.controlprogram.paramserv; + +import org.apache.sysds.parser.Statement; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.ListObject; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +/// This class implements Homomorphic Encryption (HE) for LocalParamServer. It only supports modelAvg=true. +public class HEParamServer extends LocalParamServer { + private int _thread_counter = 0; + private final List<FederatedPSControlThread> _threads; + private final List<Object> _result_buffer; // one per thread + private Object _result; + private final SEALServer _seal_server; + + public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, + MatrixObject valFeatures, MatrixObject valLabels, int nbatches) + { + return new HEParamServer(model, aggFunc, updateType, freq, ec, + workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches); + } + + private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, + MatrixObject valFeatures, MatrixObject valLabels, int nbatches) + { + super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true); + + _seal_server = new SEALServer(); + + _threads = Collections.synchronizedList(new ArrayList<>(workerNum)); + for (int i = 0; i < getNumWorkers(); i++) { + _threads.add(null); + } + + _result_buffer = new ArrayList<>(workerNum); + resetResultBuffer(); + } + + public void registerThread(int thread_id, FederatedPSControlThread thread) { + _threads.set(thread_id, thread); + } + + private synchronized void resetResultBuffer() { + _result_buffer.clear(); + for (int i = 0; i < getNumWorkers(); i++) { + _result_buffer.add(null); + } + } + + public byte[] generateA() { + return _seal_server.generateA(); + } + + public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) { + return _seal_server.aggregatePartialPublicKeys(partial_public_keys); + } + + // this method collects all T Objects from each worker into a list and then calls f once on this list to produce + // another T, which it returns. + private synchronized <T,U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) { + _result_buffer.set(workerId, obj); + _thread_counter++; + + if (_thread_counter == getNumWorkers()) { + List<T> buf = _result_buffer.stream().map(x -> (T)x).collect(Collectors.toList()); + _result = f.apply(buf); + resetResultBuffer(); + _thread_counter = 0; + notifyAll(); + } else { + try { + wait(); + } catch (InterruptedException i) { + throw new RuntimeException("thread interrupted"); + } + } + + return (U) _result; + } + + private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) { + CiphertextMatrix[] result = new CiphertextMatrix[encrypted_models.get(0).getLength()]; + for (int matrix_idx = 0; matrix_idx < encrypted_models.get(0).getLength(); matrix_idx++) { + CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()]; + for (int i = 0; i < encrypted_models.size(); i++) { + summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrix_idx); + } + result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);; + } + return result; + } + + private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) { + MatrixObject[] result = new MatrixObject[partial_decryptions.get(0).length]; + for (int matrix_idx = 0; matrix_idx < partial_decryptions.get(0).length; matrix_idx++) { + PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()]; + for (int i = 0; i < partial_decryptions.size(); i++) { + partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx]; + } + + result[matrix_idx] = _seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts); + } + + ListObject old_model = getResult(); + ListObject new_model = new ListObject(old_model); + for (int i = 0; i < new_model.getLength(); i++) { + new_model.set(i, result[i]); + } + updateAndBroadcastModel(new_model, null); + return null; + } + + @Override + public void push(int workerID, ListObject encrypted_model) { + // wait for all updates and sum them homomorphically + CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, this::homomorphicAggregation); + + // get partial decryptions + PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum); + + // do average and update global model + collectAndDo(workerID, partial_decryption, x -> this.homomorphicAverage(homomorphic_sum, x)); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java index 50c76a0f427..9fd49ca0d10 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java @@ -39,7 +39,7 @@ public static LocalParamServer create(ListObject model, String aggFunc, Statemen workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); } - private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + protected LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java index 009dc20a338..0e09fabf30b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java @@ -19,10 +19,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; +import java.util.*; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.stream.Collectors; @@ -326,30 +323,8 @@ protected synchronized void updateAverageModel(int workerID, ListObject model) { _accModels = ParamservUtils.accrueGradients(_accModels, weightParams, true); if(allFinished()) { - _model = setParams(_ec, _accModels, _model); - if (DMLScript.STATISTICS && tAgg != null) - ParamServStatistics.accAggregationTime((long) tAgg.stop()); - _accModels = null; //reset for next accumulation - - // This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch - // In the BSP batch case that occurs after the sync counter reaches the number of batches and in the - // BSP epoch case every time - if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) { - - if(LOG.isInfoEnabled()) - LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter); - time_epoch(); - if(_validationPossible) { - validate(); - } - _epochCounter++; - _syncCounter = 0; - } - // Broadcast the updated model + updateAndBroadcastModel(_accModels, tAgg); resetFinishedStates(); - broadcastModel(true); - if(LOG.isDebugEnabled()) - LOG.debug("Global parameter is broadcasted successfully "); } break; } @@ -365,7 +340,33 @@ protected synchronized void updateAverageModel(int workerID, ListObject model) { } } - protected ListObject weightModels(ListObject params, int numWorkers) { + protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg) { + _model = setParams(_ec, new_model, _model); + if (DMLScript.STATISTICS && tAgg != null) + ParamServStatistics.accAggregationTime((long) tAgg.stop()); + _accModels = null; //reset for next accumulation + + // This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch + // In the BSP batch case that occurs after the sync counter reaches the number of batches and in the + // BSP epoch case every time + if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) { + + if(LOG.isInfoEnabled()) + LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter); + time_epoch(); + if(_validationPossible) { + validate(); + } + _epochCounter++; + _syncCounter = 0; + } + // Broadcast the updated model + broadcastModel(true); + if(LOG.isDebugEnabled()) + LOG.debug("Global parameter is broadcasted successfully "); + } + + protected ListObject weightModels(ListObject params, int numWorkers) { double _averagingFactor = 1d / numWorkers; if( _averagingFactor != 1) { @@ -472,6 +473,10 @@ private void validate() { ParamServStatistics.accValidationTime((long) tValidate.stop()); } + public int getNumWorkers() { + return _numWorkers; + } + public FunctionCallCPInstruction getAggInst() { return _inst; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index 5bb3e12dcab..96979e3a5d8 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -87,6 +87,7 @@ static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) { new MatrixCharacteristics(range.getSize(0), range.getSize(1)), Types.FileFormat.BINARY) ); + slice.setPrivacyConstraints(fedMatrix.getPrivacyConstraint()); // Create new federation map List<Pair<FederatedRange, FederatedData>> newFedHashMap = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java new file mode 100644 index 00000000000..c1bb6749611 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java @@ -0,0 +1,17 @@ +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import java.io.Serializable; + +public class PublicKey implements Serializable { + private static final long serialVersionUID = 91289081237980123L; + + private final byte[] _data; + + public PublicKey(byte[] data) { + _data = data; + } + + public byte[] getData() { + return _data; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java new file mode 100644 index 00000000000..86e7b8e3fc3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java @@ -0,0 +1,55 @@ +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.utils.NativeHelper; + +import java.util.stream.IntStream; + +public class SEALClient { + public SEALClient(byte[] a) { + // TODO take params here, like slot_count etc. + // TODO init ctx and block_size + ctx = NativeHelper.initClient(a); + } + + // this is a pointer to the context used by all native methods of this class + private final long ctx; + + + // generates a partial public key and returns it + // stores a partial private key corresponding to the partial public key in ctx + public PublicKey generatePartialPublicKey() { + return new PublicKey(NativeHelper.generatePartialPublicKey(ctx)); + } + + // sets the public key and stores it in ctx + public void setPublicKey(PublicKey public_key) { + NativeHelper.setPublicKey(ctx, public_key.getData()); + } + + // encrypts one block of data with public key stored statically and returns it + // setPublicKey() must have been called before calling this + // half_block is half the size of SEAL slot_count + public CiphertextMatrix encrypt(MatrixObject plaintext) { + MatrixBlock mb = plaintext.acquireReadAndRelease(); + if (mb.isInSparseFormat()) { + mb.allocateSparseRowsBlock(); + mb.sparseToDense(); + } + DenseBlock db = mb.getDenseBlock(); + int[] dims = IntStream.range(0, db.numDims()).map(db::getDim).toArray(); + double[] raw_data = mb.getDenseBlockValues(); + return new CiphertextMatrix(dims, plaintext.getDataCharacteristics(), NativeHelper.encrypt(ctx, raw_data)); + } + + // partially decrypts one block with the partial private key. generatePartialPublicKey() must + // have been called before calling this function + // returns a block half the size of SEAL slot_count + public PlaintextMatrix partiallyDecrypt(CiphertextMatrix ciphertext) { + return new PlaintextMatrix(ciphertext.getDims(), ciphertext.getDataCharacteristics(), NativeHelper.partiallyDecrypt(ctx, ciphertext.getData())); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java new file mode 100644 index 00000000000..a7b58b2fc2f --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java @@ -0,0 +1,84 @@ +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.DenseBlockFactory; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.Encrypted; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.utils.NativeHelper; + +import java.util.Arrays; + +public class SEALServer { + public SEALServer() { + // TODO take params here, like slot_count etc. + ctx = NativeHelper.initServer(); + } + + // this is a pointer to the context used by all native methods of this class + private final long ctx; + private byte[] _a; + + // NOTICE: all long[] arys here have to be of size SEAL slot_count + // they represent the data of one Ciphertext object + // NOTICE: all double[] arys here have to be half the size of SEAL slot_count + // they represent the data of one Plaintext object + + // this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse + public synchronized byte[] generateA() { + if (_a == null) { + _a = NativeHelper.generateA(ctx); + } + return _a; + } + + // accumulates the given partial public keys into a public key, stores it in ctx and returns it + public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) { + return new PublicKey(NativeHelper.aggregatePartialPublicKeys(ctx, extractRawData(partial_public_keys))); + } + + // accumulates the given ciphertext blocks into a sum ciphertext and returns it + // stores c0 of the sum to be used in averageBlocks() + public CiphertextMatrix accumulateCiphertexts(CiphertextMatrix[] ciphertexts) { + return new CiphertextMatrix(ciphertexts[0].getDims(), ciphertexts[0].getDataCharacteristics(), NativeHelper.accumulateCiphertexts(ctx, extractRawData(ciphertexts))); + } + + // averages the partial decryptions and stores the result in old_mo + // encrypted_sum is the result of accumulateCiphertexts() and partial_plaintexts is the result of partiallyDecryptBlock + // of each ciphertext fed into accumulateCiphertexts + public MatrixObject average(CiphertextMatrix encrypted_sum, PlaintextMatrix[] partial_plaintexts) { + double[] raw_result = NativeHelper.average(ctx, encrypted_sum.getData(), extractRawData(partial_plaintexts)); + int[] dims = encrypted_sum.getDims(); + int result_len = Arrays.stream(dims).reduce(1, (x,y) -> x*y); + DataCharacteristics dc = encrypted_sum.getDataCharacteristics(); + + DenseBlock new_dense_block = DenseBlockFactory.createDenseBlock(Arrays.copyOf(raw_result, result_len), dims); + MatrixBlock new_matrix_block = new MatrixBlock((int)dc.getRows(), (int)dc.getCols(), new_dense_block); + MatrixObject new_mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(dc, Types.FileFormat.BINARY), new_matrix_block); + new_mo.exportData(); // write data, otherwise it might get evicted and thus get lost + return new_mo; + } + + private static byte[][] extractRawData(Encrypted[] data) { + byte[][] raw_data = new byte[data.length][]; + for (int i = 0; i < data.length; i++) { + raw_data[i] = data[i].getData(); + } + return raw_data; + } + + // TODO: extract an interface for this and use it here + private static byte[][] extractRawData(PublicKey[] data) { + byte[][] raw_data = new byte[data.length][]; + for (int i = 0; i < data.length; i++) { + raw_data[i] = data[i].getData(); + } + return raw_data; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java new file mode 100644 index 00000000000..9a43781c9a8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java @@ -0,0 +1,20 @@ +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified. + */ +public class CiphertextMatrix extends Encrypted { + private static final long serialVersionUID = 1762936872261940616L; + + public CiphertextMatrix(int[] dims, DataCharacteristics dc, byte[] data) { + super(dims, dc, data, Types.DataType.ENCRYPTED_CIPHER); + } + + @Override + public String getDebugName() { + return "CiphertextMatrix " + getData().hashCode(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java new file mode 100644 index 00000000000..62c00f5d788 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java @@ -0,0 +1,34 @@ +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted data. It stores the data as opaque byte array. The layout is unspecified. + */ +public abstract class Encrypted extends Data { + private static final long serialVersionUID = 1762936872268046168L; + + private final int[] _dims; + private final DataCharacteristics _dc; + private final byte[] _data; + + public Encrypted(int[] dims, DataCharacteristics dc, byte[] data, Types.DataType dt) { + super(dt, Types.ValueType.UNKNOWN); + _dims = dims; + _dc = dc; + _data = data; + } + + public int[] getDims() { + return _dims; + } + + public DataCharacteristics getDataCharacteristics() { + return _dc; + } + + public byte[] getData() { + return _data; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java index 38288178e4e..5c302fe80a9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java @@ -397,6 +397,18 @@ public void writeExternal(ObjectOutput out) throws IOException { ScalarObject so = (ScalarObject) d; out.writeObject(so.getStringValue()); break; + case ENCRYPTED_CIPHER: + case ENCRYPTED_PLAIN: + Encrypted e = (Encrypted) d; + int[] dims = e.getDims(); + dc = e.getDataCharacteristics(); + out.writeObject(dims); + out.writeObject(dc.getRows()); + out.writeObject(dc.getCols()); + out.writeObject(dc.getBlocksize()); + out.writeObject(dc.getNonZeros()); + out.writeObject(e.getData()); + break; default: throw new DMLRuntimeException("Unable to serialize datatype " + dataType); } @@ -463,6 +475,21 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept } d = so; break; + case ENCRYPTED_CIPHER: + case ENCRYPTED_PLAIN: + int[] dims = (int[]) in.readObject(); + rows = (long) in.readObject(); + cols = (long) in.readObject(); + blockSize = (int) in.readObject(); + nonZeros = (long) in.readObject(); + byte[] data = (byte[])in.readObject(); + DataCharacteristics dc = new MatrixCharacteristics(rows, cols, blockSize, nonZeros); + if (dataType == DataType.ENCRYPTED_CIPHER) { + d = new CiphertextMatrix(dims, dc, data); + } else { + d = new PlaintextMatrix(dims, dc, data); + } + break; default: throw new DMLRuntimeException("Unable to deserialize datatype " + dataType); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 25353f60039..8c6e4566514 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -71,22 +71,17 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread; -import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker; -import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkParamservUtils; +import org.apache.sysds.runtime.controlprogram.paramserv.*; import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme; import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner; import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.utils.stats.ParamServStatistics; @@ -188,23 +183,52 @@ private void runFederated(ExecutionContext ec) { MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) ? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null; MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null; boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG)); - ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(), - getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg); + + // check if we need homomorphic encryption + boolean use_homomorphic_encryption_ = false; + for (int i = 0; i < workerNum; i++) { + use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pFeatures.get(i)); + use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pLabels.get(i)); + } + final boolean use_homomorphic_encryption = use_homomorphic_encryption_; + if (use_homomorphic_encryption && !modelAvg) { + throw new DMLRuntimeException("can't use homomorphic encryption without modelAvg"); + } + + LocalParamServer ps = (LocalParamServer)createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(), + getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg, use_homomorphic_encryption); // Create the local workers int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics); List<FederatedPSControlThread> threads = IntStream.range(0, workerNum) .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting, - getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg)) + getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg, use_homomorphic_encryption)) .collect(Collectors.toList()); if(workerNum != threads.size()) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!"); } + // Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers for (int i = 0; i < threads.size(); i++) { threads.get(i).setFeatures(result._pFeatures.get(i)); threads.get(i).setLabels(result._pLabels.get(i)); threads.get(i).setup(result._weightingFactors.get(i)); } + + if (use_homomorphic_encryption) { + // generate public key from partial public keys + PublicKey[] partial_public_keys = new PublicKey[threads.size()]; + for (int i = 0; i < threads.size(); i++) { + partial_public_keys[i] = threads.get(i).getPartialPublicKey(); + } + + // TODO: accumulate public keys with SEAL + PublicKey public_key = ((HEParamServer)ps).aggregatePartialPublicKeys(partial_public_keys); + + for (FederatedPSControlThread thread : threads) { + thread.setPublicKey(public_key); + } + } + if (DMLScript.STATISTICS) ParamServStatistics.accSetupTime((long) tSetup.stop()); @@ -479,21 +503,32 @@ private int getWorkerNum(PSModeType mode) { * @return parameter server */ private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, - PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) + PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) { return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg); } + + private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, + PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, + int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) { + return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, false); + } + // When this creation is used the parameter server is able to validate after each epoch private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, - int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) + int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption) { switch (mode) { case FEDERATED: case LOCAL: case REMOTE_SPARK: - return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); + if (use_homomorphic_encryption) { + return HEParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches); + } else { + return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); + } default: throw new DMLRuntimeException("Unsupported parameter server: " + mode.name()); } @@ -614,4 +649,9 @@ private int getNbatches() { } return Integer.parseInt(getParam(PS_NBATCHES)); } + + private boolean checkIsPrivate(MatrixObject obj) { + PrivacyConstraint pc = obj.getPrivacyConstraint(); + return pc != null && pc.hasPrivateElements(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java new file mode 100644 index 00000000000..0c745fa0aaa --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java @@ -0,0 +1,20 @@ +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified. + */ +public class PlaintextMatrix extends Encrypted { + private static final long serialVersionUID = 5732436872261940616L; + + public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) { + super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN); + } + + @Override + public String getDebugName() { + return "PlaintextMatrix " + getData().hashCode(); + } +} diff --git a/src/main/java/org/apache/sysds/utils/NativeHelper.java b/src/main/java/org/apache/sysds/utils/NativeHelper.java index 83869c23d2f..72eea72b475 100644 --- a/src/main/java/org/apache/sysds/utils/NativeHelper.java +++ b/src/main/java/org/apache/sysds/utils/NativeHelper.java @@ -44,14 +44,14 @@ * By default, it first tries to load Intel MKL, else tries to load OpenBLAS. */ public class NativeHelper { - + public enum NativeBlasState { NOT_ATTEMPTED_LOADING_NATIVE_BLAS, SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE, SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE, ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY } - + public static NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS; private static String blasType; @@ -63,16 +63,16 @@ public enum NativeBlasState { /** * Called by Statistics to print the loaded BLAS. - * + * * @return empty string or the BLAS that is loaded */ public static String getCurrentBLAS() { return blasType != null ? blasType : ""; } - + /** * Called by runtime to check if the BLAS is available for exploitation - * + * * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE else false */ public static boolean isNativeLibraryLoaded() { @@ -99,10 +99,10 @@ public static boolean isNativeLibraryLoaded() { } return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE; } - + /** - * Initialize the native library before executing the DML program - * + * Initialize the native library before executing the DML program + * * @param customLibPath specified by sysds.native.blas.directory * @param userSpecifiedBLAS specified by sysds.native.blas */ @@ -121,22 +121,22 @@ else if(!isBLASLoaded() && isSupportedBLAS(userSpecifiedBLAS)) { performLoading(customLibPath, userSpecifiedBLAS); } } - + /** * Return true if the given BLAS type is supported. - * + * * @param userSpecifiedBLAS BLAS type specified via sysds.native.blas property * @return true if the userSpecifiedBLAS is auto | mkl | openblas, else false */ private static boolean isSupportedBLAS(String userSpecifiedBLAS) { - return userSpecifiedBLAS.equalsIgnoreCase("auto") || - userSpecifiedBLAS.equalsIgnoreCase("mkl") || + return userSpecifiedBLAS.equalsIgnoreCase("auto") || + userSpecifiedBLAS.equalsIgnoreCase("mkl") || userSpecifiedBLAS.equalsIgnoreCase("openblas"); } - + /** * Note: we only support 64 bit Java on x86 and AMD machine - * + * * @return true if the hardware architecture is supported */ private static boolean isSupportedArchitecture() { @@ -166,21 +166,21 @@ private static boolean isSupportedOS() { * SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE */ private static boolean isBLASLoaded() { - return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || + return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE; } - + /** * Check if we should attempt to perform loading. * If custom library path is provided, we should attempt to load again if not already loaded. - * - * @param customLibPath custom library path + * + * @param customLibPath custom library path * @return true if we should attempt to load blas again */ private static boolean shouldReload(String customLibPath) { boolean isValidBLASDirectory = customLibPath != null && !customLibPath.equalsIgnoreCase("none"); return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS || - (isValidBLASDirectory && !isBLASLoaded()); + (isValidBLASDirectory && !isBLASLoaded()); } // Performing loading in a method instead of a static block will throw a detailed stack trace in case of fatal errors @@ -191,13 +191,13 @@ private static void performLoading(String customLibPath, String userSpecifiedBLA // attemptedLoading variable ensures that we don't try to load SystemDS and other dependencies // again and again especially in the parfor (hence the double-checking with synchronized). if(shouldReload(customLibPath) && isSupportedBLAS(userSpecifiedBLAS) && isSupportedArchitecture() - && isSupportedOS()) { + && isSupportedOS()) { long start = System.nanoTime(); synchronized(NativeHelper.class) { if(shouldReload(customLibPath)) { // Set attempted loading unsuccessful in case of exception CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY; - String [] blas = new String[] { userSpecifiedBLAS }; + String [] blas = new String[] { userSpecifiedBLAS }; if(userSpecifiedBLAS.equalsIgnoreCase("auto")) { blas = new String[] { "mkl", "openblas" }; } @@ -206,7 +206,7 @@ && isSupportedOS()) { String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so"); String library_name = "libsystemds_" + blasType + platform_suffix; if(loadLibraryHelperFromResource(library_name) || - loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath.")) + loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath.")) { LOG.info("Using native blas: " + blasType + getNativeBLASPath()); CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE; @@ -215,15 +215,15 @@ && isSupportedOS()) { } } double timeToLoadInMilliseconds = (System.nanoTime()-start)*1e-6; - if(timeToLoadInMilliseconds > 1000) + if(timeToLoadInMilliseconds > 1000) LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds."); } else if(LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) { LOG.debug("Using internal Java BLAS as native BLAS support instead of the configuration " + - "'sysds.native.blas'=" + userSpecifiedBLAS + "."); + "'sysds.native.blas'=" + userSpecifiedBLAS + "."); } } - + private static boolean checkAndLoadBLAS(String customLibPath, String [] listBLAS) { if(customLibPath != null && customLibPath.equalsIgnoreCase("none")) customLibPath = null; @@ -250,10 +250,10 @@ else if (blas.equalsIgnoreCase("openblas")) { } return isLoaded; } - + /** * Useful method for debugging. - * + * * @return empty string (if !LOG.isDebugEnabled()) or the path from where openblas or mkl is loaded. */ private static String getNativeBLASPath() { @@ -287,8 +287,8 @@ public static int getMaxNumThreads() { /** * Attempts to load native BLAS - * - * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the + * + * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the * @param blas can be gomp, openblas or mkl_rt * @param optionalMsg message for debugging * @return true if successfully loaded BLAS @@ -300,8 +300,8 @@ public static boolean loadBLAS(String customLibPath, String blas, String optiona try { // This fixes libPath if it already contained a prefix/suffix and mapLibraryName added another one. libPath = libPath.replace("liblibsystemds", "libsystemds") - .replace(".dll.dll", ".dll") - .replace(".so.so", ".so"); + .replace(".dll.dll", ".dll") + .replace(".so.so", ".so"); System.load(libPath); LOG.info("Loaded the library:" + libPath); return true; @@ -321,7 +321,7 @@ public static boolean loadBLAS(String customLibPath, String blas, String optiona catch (UnsatisfiedLinkError e) { LOG.debug("java.library.path: " + System.getProperty("java.library.path")); LOG.debug("Unable to load " + blas + (optionalMsg == null ? "" : (" (" + optionalMsg + ")")) + - " \n Message from exception was: " + e.getMessage()); + " \n Message from exception was: " + e.getMessage()); return false; } } @@ -355,13 +355,13 @@ public static boolean loadLibraryHelperFromResource(String libFileName) { } // TODO: Add pmm, wsloss, mmchain, etc. - + //double-precision matrix multiply dense-dense public static native long dmmdd(double [] m1, double [] m2, double [] ret, int m1rlen, int m1clen, int m2clen, - int numThreads); + int numThreads); //single-precision matrix multiply dense-dense public static native long smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, int m1rlen, int m1clen, int m2clen, - int numThreads); + int numThreads); //transpose-self matrix multiply public static native long tsmm(double[] m1, double[] ret, int m1rlen, int m1clen, boolean leftTrans, int numThreads); @@ -374,27 +374,27 @@ public static native long smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, // Returns -1 if failures or returns number of nonzeros // Called by DnnCPInstruction if both input and filter are dense - public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, - int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads); + public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, + int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads); public static native long dconv2dBiasAddDense(double [] input, double [] bias, double [] filter, double [] ret, int N, - int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, - int numThreads); + int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, + int numThreads); public static native long sconv2dBiasAddDense(FloatBuffer input, FloatBuffer bias, FloatBuffer filter, FloatBuffer ret, - int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, - int numThreads); + int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, + int numThreads); // Called by DnnCPInstruction if both input and filter are dense public static native long conv2dBackwardFilterDense(double [] input, double [] dout, double [] ret, int N, int C, - int H, int W, int K, int R, int S, int stride_h, int stride_w, - int pad_h, int pad_w, int P, int Q, int numThreads); + int H, int W, int K, int R, int S, int stride_h, int stride_w, + int pad_h, int pad_w, int P, int Q, int numThreads); // If both filter and dout are dense, then called by DnnCPInstruction // Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse. public static native long conv2dBackwardDataDense(double [] filter, double [] dout, double [] ret, int N, int C, - int H, int W, int K, int R, int S, int stride_h, int stride_w, - int pad_h, int pad_w, int P, int Q, int numThreads); + int H, int W, int K, int R, int S, int stride_h, int stride_w, + int pad_h, int pad_w, int P, int Q, int numThreads); // Currently only supported with numThreads = 1 and sparse input // Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse. @@ -415,4 +415,50 @@ public static native boolean conv2dSparse(int apos, int alen, int[] aix, double[ // different tradeoffs. In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be // fastest. We can revisit this decision later and hence I would not recommend removing this method. private static native void setMaxNumThreads(int numThreads); + + // ---------------------------------------------------------------------------------------------------------------- + // SEAL integration + // ---------------------------------------------------------------------------------------------------------------- + + // these are called by SEALClient + + // generates a Client object and returns a pointer to it as long + public static native long initClient(byte[] a); + + // generates a partial public key and returns it + // stores a partial private key corresponding to the partial public key in client + public static native byte[] generatePartialPublicKey(long client); + + // sets the public key and stores it in client + public static native void setPublicKey(long client, byte[] public_key); + + // encrypts data with public key stored in client and returns it + // setPublicKey() must have been called before calling this + public static native byte[] encrypt(long client, double[] plaintexts); + + // partially decrypts ciphertexts with the partial private key. generatePartialPublicKey() must + // have been called before calling this function + public static native byte[] partiallyDecrypt(long client, byte[] ciphertexts); + + // ---------------------------------------------------------------------------------------------------------------- + + // these are called by SEALServer + + // generates the Server object and returns a pointer to it as long + public static native long initServer(); + + // this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse + public static native byte[] generateA(long server); + + // accumulates the given partial public keys into a public key, stores it in server and returns it + public static native byte[] aggregatePartialPublicKeys(long server, byte[][] partial_public_keys); + + // accumulates the given ciphertexts into a sum ciphertext and returns it + // stores c0 of the sum to be used in averageBlocks() + public static native byte[] accumulateCiphertexts(long server, byte[][] ciphertexts); + + // averages the partial decryptions and returns the result + // encrypted_sum is the result of accumulateCiphertexts() and partial_plaintexts is the result of partiallyDecrypt + // of each ciphertext fed into accumulateCiphertexts + public static native double[] average(long server, byte[] encrypted_sum, byte[][] partial_plaintexts); } diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index d960dac327a..942c5455008 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -652,6 +652,12 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ */ protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges) + { + rowFederateLocallyAndWriteInputMatrixWithMTD(name, matrix, numFederatedWorkers, ports, ranges, null); + } + + protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, + double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges, PrivacyConstraint privacyConstraint) { // check matrix non empty if(matrix.length == 0 || matrix[0].length == 0) @@ -677,7 +683,7 @@ protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, // write slice writeInputMatrixWithMTD(path, Arrays.copyOfRange(matrix, (int)lowerBound, (int)upperBound), false, new MatrixCharacteristics((long) examplesForWorkerI, ncol, - OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol)); + OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol), privacyConstraint); // generate fedmap entry FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol}); @@ -688,7 +694,7 @@ false, new MatrixCharacteristics((long) examplesForWorkerI, ncol, federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap)); federatedMatrixObject.getFedMapping().setType(FType.ROW); - writeInputFederatedWithMTD(name, federatedMatrixObject, null); + writeInputFederatedWithMTD(name, federatedMatrixObject, privacyConstraint); } protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java new file mode 100644 index 00000000000..28ac6c94058 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.paramserv; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.codegen.SpoofCompiler; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.NativeHelper; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class EncryptedFederatedParamservTest extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(EncryptedFederatedParamservTest.class.getName()); + private final static String TEST_DIR = "functions/federated/paramserv/"; + private final static String TEST_NAME = "EncryptedFederatedParamservTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + EncryptedFederatedParamservTest.class.getSimpleName() + "/"; + + private final String _networkType; + private final int _numFederatedWorkers; + private final int _dataSetSize; + private final int _epochs; + private final int _batch_size; + private final double _eta; + private final String _utype; + private final String _freq; + private final String _scheme; + private final String _runtime_balancing; + private final String _weighting; + private final String _data_distribution; + private final int _seed; + + // parameters + @Parameterized.Parameters + public static Collection<Object[]> parameters() { + return Arrays.asList(new Object[][] { + // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency + // basic functionality + //{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED", 200}, + + // One important point is that we do the model averaging in the case of BSP + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "BASELINE", "true", "IMBALANCED", 200}, + {"CNN", 2, 4, 1, 1, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "BASELINE", "true", "IMBALANCED", 200}, + //{"TwoNN", 5, 1000, 100, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED", 200}, + + /* + // runtime balancing + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED", 200}, + + // data partitioning + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE", "true", "IMBALANCED", 200}, + + // balanced tests + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED", 200} + */ + }); + } + + public EncryptedFederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, + int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighting, String data_distribution, int seed) { + try { + NativeHelper.initialize("none", "openblas"); + } catch (Exception e) { + throw e; + } + _networkType = networkType; + _numFederatedWorkers = numFederatedWorkers; + _dataSetSize = dataSetSize; + _batch_size = batch_size; + _epochs = epochs; + _eta = eta; + _utype = utype; + _freq = freq; + _scheme = scheme; + _runtime_balancing = runtime_balancing; + _weighting = weighting; + _data_distribution = data_distribution; + _seed = seed; + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void EncryptedfederatedParamservSingleNode() { + EncryptedfederatedParamserv(ExecMode.SINGLE_NODE, true); + } + + @Test + public void EncryptedfederatedParamservHybrid() { + EncryptedfederatedParamserv(ExecMode.HYBRID, true); + } + + private void EncryptedfederatedParamserv(ExecMode mode, boolean modelAvg) { + // Warning Statistics accumulate in unit test + // config + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + setOutputBuffering(true); + + int C = 1, Hin = 28, Win = 28; + int numLabels = 10; + + ExecMode platformOld = setExecMode(mode); + + try { + // start threads + List<Integer> ports = new ArrayList<>(); + List<Thread> threads = new ArrayList<>(); + for(int i = 0; i < _numFederatedWorkers; i++) { + ports.add(getRandomAvailablePort()); + threads.add(startLocalFedWorkerThread(ports.get(i), + i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S)); + } + + // generate test data + double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win); + double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels); + String featuresName = ""; + String labelsName = ""; + + PrivacyConstraint privacyConstraint = new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private); + + // federate test data balanced or imbalanced + if(_data_distribution.equals("IMBALANCED")) { + featuresName = "X_IMBALANCED_" + _numFederatedWorkers; + labelsName = "y_IMBALANCED_" + _numFederatedWorkers; + double[][] ranges = {{0,1}, {1,4}}; + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint); + } + else { + featuresName = "X_BALANCED_" + _numFederatedWorkers; + labelsName = "y_BALANCED_" + _numFederatedWorkers; + double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length); + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint); + } + + try { + //wait for all workers to be setup + Thread.sleep(FED_WORKER_WAIT); + } + catch(InterruptedException e) { + e.printStackTrace(); + } + + // dml name + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + // generate program args + List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats", + "-nvargs", + "features=" + input(featuresName), + "labels=" + input(labelsName), + "epochs=" + _epochs, + "batch_size=" + _batch_size, + "eta=" + _eta, + "utype=" + _utype, + "freq=" + _freq, + "scheme=" + _scheme, + "runtime_balancing=" + _runtime_balancing, + "weighting=" + _weighting, + "network_type=" + _networkType, + "channels=" + C, + "hin=" + Hin, + "win=" + Win, + "seed=" + _seed, + "modelAvg=" + Boolean.toString(modelAvg).toUpperCase())); + + programArgs = programArgsList.toArray(new String[0]); + String log = runTest(null).toString(); + Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst()); + + // shut down threads + for(int i = 0; i < _numFederatedWorkers; i++) { + TestUtils.shutdownThreads(threads.get(i)); + } + } + finally { + resetExecMode(platformOld); + } + } + + /** + * Generates an feature matrix that has the same format as the MNIST dataset, + * but is completely random and normalized + * + * @param numExamples Number of examples to generate + * @param C Channels in the input data + * @param Hin Height in Pixels of the input data + * @param Win Width in Pixels of the input data + * @return a dummy MNIST feature matrix + */ + private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1); + } + + /** + * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists + * of one hot encoded vectors as rows + * + * @param numExamples Number of examples to generate + * @param numLabels Number of labels to generate + * @return a dummy MNIST lable matrix + */ + private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java new file mode 100644 index 00000000000..366ace77921 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java @@ -0,0 +1,98 @@ +package org.apache.sysds.test.functions.homomorphicEncryption; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.NativeHelper; +import org.junit.Test; + +public class InOutTest extends AutomatedTestBase { + private final static String TEST_NAME = "InOutTest"; + private final static String TEST_DIR = "functions/data/"; + private final static String TEST_CLASS_DIR = TEST_DIR + InOutTest.class.getSimpleName() + "/"; + + private final int num_clients = 3; + + private final int rows = 100; + private final int cols = 200; + private final long seed = 42; + + @Override + public void setUp() { + try { + NativeHelper.initialize("none", "openblas"); + } catch (Exception e) { + throw e; + } + + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "C" }) ); + } + + @Test + public void endToEndTest() { + SEALServer server = new SEALServer(); + + SEALClient[] clients = new SEALClient[num_clients]; + PublicKey[] partial_pub_keys = new PublicKey[num_clients]; + for (int i = 0; i < num_clients; i++) { + clients[i] = new SEALClient(server.generateA()); + partial_pub_keys[i] = clients[i].generatePartialPublicKey(); + } + + PublicKey public_key = server.aggregatePartialPublicKeys(partial_pub_keys); + + MatrixObject[] plaintexts = new MatrixObject[num_clients]; + CiphertextMatrix[] ciphertexts = new CiphertextMatrix[num_clients]; + for (int i = 0; i < num_clients; i++) { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, -100, 100, 1.0, seed+i); + MatrixObject mo = new MatrixObject(Types.ValueType.FP64, null); + mo.setMetaData(new MetaDataFormat(new MatrixCharacteristics(rows, cols), Types.FileFormat.BINARY)); + mo.acquireModify(mb); + mo.release(); + plaintexts[i] = mo; + + clients[i].setPublicKey(public_key); + ciphertexts[i] = clients[i].encrypt(plaintexts[i]); + } + + CiphertextMatrix encrypted_sum = server.accumulateCiphertexts(ciphertexts); + + PlaintextMatrix[] partial_decryptions = new PlaintextMatrix[num_clients]; + for (int i = 0; i < num_clients; i++) { + partial_decryptions[i] = clients[i].partiallyDecrypt(encrypted_sum); + } + + MatrixObject result = server.average(encrypted_sum, partial_decryptions); + + double[] expected_raw_result = new double[rows*cols]; + double[][] plaintexts_raw = new double[num_clients][]; + for (int i = 0; i < num_clients; i++) { + plaintexts_raw[i] = plaintexts[i].acquireReadAndRelease().getDenseBlockValues(); + } + for (int x = 0; x < rows * cols; x++) { + double sum = 0.0; + for (int i = 0; i < num_clients; i++) { + sum += plaintexts_raw[i][x]; + } + expected_raw_result[x] = sum / num_clients; + } + + double[] raw_result = result.acquireReadAndRelease().getDenseBlockValues(); + assert result.getNumRows() == rows; + assert result.getNumColumns() == cols; + assert raw_result.length == rows*cols; + TestUtils.compareMatrices(raw_result, expected_raw_result, 5e-8); + } +} diff --git a/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml new file mode 100644 index 00000000000..b8021867dc9 --- /dev/null +++ b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml @@ -0,0 +1,61 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN +source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as TwoNNModelAvg +source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN +source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as CNNModelAvg + + +# create federated input matrices +features = read($features) +labels = read($labels) + +if($network_type == "TwoNN") { + if(!as.logical($modelAvg)) { + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed) + print("Test results:") + [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list()) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } + else if (as.logical($modelAvg)){ + model = TwoNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $modelAvg) + print("Test results:") + [loss_test, accuracy_test] = TwoNNModelAvg::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list()) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } +} +else if($network_type == "CNN") { + if(!as.logical($modelAvg)) { + model = CNN::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed) + print("Test results:") + hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win) + [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } + else if (as.logical($modelAvg)){ + model = CNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed, $modelAvg) + print("Test results:") + hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win) + [loss_test, accuracy_test] = CNNModelAvg::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } +}