diff --git a/src/main/cpp/build.bat b/src/main/cpp/build.bat index 93c3819c3be..316c4f393df 100644 --- a/src/main/cpp/build.bat +++ b/src/main/cpp/build.bat @@ -34,5 +34,9 @@ cmake . -B OPENBLAS -DUSE_OPEN_BLAS=ON -DCMAKE_BUILD_TYPE=Release cmake --build OPENBLAS --target install --config Release rmdir /Q /S OPENBLAS +cmake he\ -B HE -DCMAKE_BUILD_TYPE=Release +cmake --build HE --target install --config Release +rmdir /Q /S HE + echo. echo "Make sure to re-run mvn package to make use of the newly compiled libraries" \ No newline at end of file diff --git a/src/main/cpp/build.sh b/src/main/cpp/build.sh index df67aba5392..e40ec895a3c 100755 --- a/src/main/cpp/build.sh +++ b/src/main/cpp/build.sh @@ -66,3 +66,8 @@ ldd lib/libsystemds_mkl-Linux-x86_64.so | grep -v $gcc_toolkit"\|$linux_loader\| echo "Non-standard dependencies for libsystemds_openblas-linux-x86_64.so" ldd lib/libsystemds_openblas-Linux-x86_64.so | grep -v $gcc_toolkit"\|$linux_loader\|"$openblas echo "-----------------------------------------------------------------------" + +# compile HE +cmake he/ -B HE -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=g++ +cmake --build HE --target install --config Release +rm -R HE \ No newline at end of file diff --git a/src/main/cpp/he/CMakeLists.txt b/src/main/cpp/he/CMakeLists.txt new file mode 100644 index 00000000000..373ba3a5d99 --- /dev/null +++ b/src/main/cpp/he/CMakeLists.txt @@ -0,0 +1,64 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +cmake_minimum_required(VERSION 3.8) +cmake_policy(SET CMP0074 NEW) # make use of _ROOT variable +project (he LANGUAGES CXX) + +# All custom find modules +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/") + +# Build a shared libraray +set(HEADER_FILES libhe.h he.h) +set(SOURCE_FILES libhe.cpp he.cpp) + +# Build a shared libraray +add_library(he SHARED ${SOURCE_FILES} ${HEADER_FILES}) + +set_target_properties(he PROPERTIES MACOSX_RPATH 1) + +# sets the installation path to src/main/cpp/lib +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CMAKE_SOURCE_DIR}/.." CACHE PATH "sets the installation path to src/main/cpp/lib" FORCE) +endif() + +# sets the installation path to src/main/cpp/lib +# install(TARGETS he LIBRARY DESTINATION lib) +install(TARGETS he RUNTIME DESTINATION lib) + +# unify library filenames to libhe_<...> +if (WIN32) + set(CMAKE_IMPORT_LIBRARY_PREFIX lib CACHE INTERNAL "") + set(CMAKE_SHARED_LIBRARY_PREFIX lib CACHE INTERNAL "") +endif() + +set(CMAKE_BUILD_TYPE Release) +set_target_properties(he PROPERTIES OUTPUT_NAME "he-${CMAKE_SYSTEM_NAME}-${CMAKE_SYSTEM_PROCESSOR}") + +find_package(SEAL 3.7 REQUIRED) +target_link_libraries(he SEAL::seal_shared) + +# Include directories. (added for Linux & Darwin, fix later for windows) +# include paths can be spurious +include_directories($ENV{JAVA_HOME}/include/) +include_directories($ENV{JAVA_HOME}/include/darwin) +include_directories($ENV{JAVA_HOME}/include/linux) +include_directories($ENV{JAVA_HOME}/include/win32) diff --git a/src/main/cpp/he/he.cpp b/src/main/cpp/he/he.cpp new file mode 100644 index 00000000000..f9bad7e9846 --- /dev/null +++ b/src/main/cpp/he/he.cpp @@ -0,0 +1,279 @@ +#include "he.h" +#include "libhe.h" + +#ifdef _WIN32 +#include +#else +#include +#endif + +unique_ptr get_stream(JNIEnv* env, jbyteArray ary) { + size_t size = env->GetArrayLength(ary); + jbyte* data = env->GetByteArrayElements(ary, NULL); + + // FIXME: this copies string data once. maybe implement a custom stream + // idea: implement a custom stream that wraps a jbyteArray, which calls ReleaseByteArrayElements in its d'tor + string data_s = string(reinterpret_cast(data), size); + unique_ptr ret = std::make_unique(std::move(data_s)); + env->ReleaseByteArrayElements(ary, data, JNI_ABORT); + return ret; +} + +jbyteArray allocate_byte_array(JNIEnv* env, ostringstream& stream) { + string data = stream.str(); // FIXME: this copies string content. maybe implement custom ostream + jbyteArray ret = env->NewByteArray(data.size()); + env->SetByteArrayRegion(ret, 0, data.size(), reinterpret_cast(data.data())); + return ret; +} + +void my_assert(bool assertion, const char* message = "Assertion failed") { + if (!assertion) { + throw logic_error(message); + } +} + +template jbyteArray serialize(JNIEnv* env, T& object) { + ostringstream ss; + object.save(ss); + return allocate_byte_array(env, ss); +} + +void serialize_uint32_t(ostream& ss, uint32_t n) { + n = htonl(n); + ss.write(reinterpret_cast(&n), sizeof(n)); +} + +uint32_t deserialize_uint32_t(istream& ss) { + uint32_t ret; + ss.read(reinterpret_cast(&ret), sizeof(ret)); + ret = ntohl(ret); + return ret; +} + +Ciphertext deserialize_ciphertext(istream& ss, const SEALContext& context) { + Ciphertext ret; + ret.load(context, ss); + return ret; +} + +void serialize_plaintext(ostream& ss, Plaintext plaintext) { + plaintext.save(ss); +} + +template T deserialize_unsafe(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) { + auto ss = get_stream(env, serialized_object); + T deserialized; + deserialized.unsafe_load(context, *ss); // necessary bc partial public keys are not valid public keys + return deserialized; +} + +template T deserialize(JNIEnv* env, const SEALContext& context, jbyteArray serialized_object) { + auto ss = get_stream(env, serialized_object); + T deserialized; + deserialized.load(context, *ss); // necessary bc partial public keys are not valid public keys + return deserialized; +} + +JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initClient + (JNIEnv* env, jclass, jbyteArray a_ary) { + double scale = pow(2.0, 40); + GlobalState gs(scale); + + // copy a to global state + size_t byte_size = env->GetArrayLength(a_ary); + my_assert(byte_size % sizeof(uint64_t) == 0); + size_t size = byte_size / sizeof(uint64_t); + uint64_t* a = reinterpret_cast(env->GetByteArrayElements(a_ary, NULL)); + gsl::span new_a(a, size); + + vector new_a_buf; + new_a_buf.assign(new_a.begin(), new_a.end()); + gs.a.set_data(new_a_buf); + + // release a without back-copy + env->ReleaseByteArrayElements(a_ary, reinterpret_cast(a), JNI_ABORT); + + Client* client = new Client(gs); + return reinterpret_cast(client); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generatePartialPublicKey + (JNIEnv* env, jclass, jlong client_ptr) { + Client* client = reinterpret_cast(client_ptr); + return serialize(env, client->partial_public_key().data()); +} + + +JNIEXPORT void JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_setPublicKey + (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_public_key) { + Client* client = reinterpret_cast(client_ptr); + client->set_public_key(deserialize(env, client->context(), serialized_public_key)); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_encrypt + (JNIEnv* env, jclass, jlong client_ptr, jdoubleArray jdata) { + Client* client = reinterpret_cast(client_ptr); + size_t slot_count = get_slot_count(client->context()); + size_t num_data = env->GetArrayLength(jdata); + const double* data = static_cast(env->GetDoubleArrayElements(jdata, NULL)); + + std::ostringstream ss; + // write chunk size + uint32_t num_chunks = (num_data - 1) / slot_count + 1; + serialize_uint32_t(ss, num_chunks); + for (size_t i = 0; i < num_chunks; i++) { + size_t offset = slot_count * i; + size_t length = min(slot_count, num_data-offset); + gsl::span data_span(&data[offset], length); + Ciphertext encrypted_chunk = client->encrypted_data(data_span); + encrypted_chunk.save(ss); + } + env->ReleaseDoubleArrayElements(jdata, const_cast(data), JNI_ABORT); + return allocate_byte_array(env, ss); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_partiallyDecrypt + (JNIEnv* env, jclass, jlong client_ptr, jbyteArray serialized_ciphertexts) { + Client* client = reinterpret_cast(client_ptr); + auto input = get_stream(env, serialized_ciphertexts); + std::ostringstream ss; + + // read num of chunks + uint32_t num_chunks = deserialize_uint32_t(*input); + + // write chunk size + serialize_uint32_t(ss, num_chunks); + for (int i = 0; i < num_chunks; i++) { + Ciphertext ciphertext = deserialize_ciphertext(*input, client->context()); + Plaintext plaintext = client->partial_decryption(ciphertext); + plaintext.save(ss); + } + + return allocate_byte_array(env, ss); +} + + +JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initServer + (JNIEnv *, jclass) { + double scale = pow(2.0, 40); + GlobalState gs(scale); + Server* server = new Server(gs); + return reinterpret_cast(server); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generateA + (JNIEnv* env, jclass, jlong server_ptr) { + Server* server = reinterpret_cast(server_ptr); + uint64_t* data = server->a().data(); + size_t size = server->a().size() * sizeof(data[0]) / sizeof(jbyte); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, reinterpret_cast(data)); + return ret; +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_aggregatePartialPublicKeys + (JNIEnv* env, jclass, jlong server_ptr, jobjectArray partial_public_keys_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t num_partial_public_keys = env->GetArrayLength(partial_public_keys_serialized); + std::vector partial_public_keys; + partial_public_keys.reserve(num_partial_public_keys); + + for (int i = 0; i < num_partial_public_keys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(partial_public_keys_serialized, i)); + partial_public_keys.push_back(deserialize_unsafe(env, server->context(), j_data)); + env->DeleteLocalRef(j_data); + } + + server->accumulate_partial_public_keys(gsl::span(partial_public_keys)); + return serialize(env, server->public_key()); +} + + +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_accumulateCiphertexts + (JNIEnv* env, jclass, jlong server_ptr, jobjectArray ciphertexts_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t num_ciphertext_arys = env->GetArrayLength(ciphertexts_serialized); + + // init streams + vector> buf; + buf.reserve(num_ciphertext_arys); + for (int i = 0; i < num_ciphertext_arys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(ciphertexts_serialized, i)); + auto stream = get_stream(env, j_data); + buf.emplace_back(std::move(stream)); + env->DeleteLocalRef(j_data); + } + + // read lengths of ciphertext arys and check that they are all the same + uint32_t num_slots = deserialize_uint32_t(*buf[0]); + for (int i = 1; i < num_ciphertext_arys; i++) { + my_assert(deserialize_uint32_t(*buf[i]) == num_slots); + } + + // read ciphertexts in chunks and accumulate them + ostringstream result; + serialize_uint32_t(result, num_slots); + for (int chunk_idx = 0; chunk_idx < num_slots; chunk_idx++) { + vector ciphertexts; + ciphertexts.reserve(num_ciphertext_arys); + for (int i = 0; i < num_ciphertext_arys; i++) { + Ciphertext deserialized; + deserialized.load(server->context(), *buf[i]); + ciphertexts.emplace_back(deserialized); + } + Ciphertext sum = server->sum_data(std::move(ciphertexts)); + sum.save(result); + } + + return allocate_byte_array(env, result); +} + + +JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_average + (JNIEnv* env, jclass, jlong server_ptr, jbyteArray ciphertext_sum_serialized, jobjectArray partial_decryptions_serialized) { + Server* server = reinterpret_cast(server_ptr); + size_t slot_size = get_slot_count(server->context()); + size_t num_plaintext_arys = env->GetArrayLength(partial_decryptions_serialized); + + // init streams + vector> buf; + buf.reserve(num_plaintext_arys); + for (int i = 0; i < num_plaintext_arys; i++) { + jbyteArray j_data = static_cast(env->GetObjectArrayElement(partial_decryptions_serialized, i)); + auto stream = get_stream(env, j_data); + buf.emplace_back(std::move(stream)); + env->DeleteLocalRef(j_data); + } + + // read lengths of ciphertext arys and check that they are all the same + uint32_t num_slots = deserialize_uint32_t(*buf[0]); + for (int i = 1; i < num_plaintext_arys; i++) { + my_assert(deserialize_uint32_t(*buf[i]) == num_slots, "number of plaintext slots is different"); + } + + auto encrypted_sum_stream = get_stream(env, ciphertext_sum_serialized); + my_assert(deserialize_uint32_t(*encrypted_sum_stream) == num_slots, "number of ciphertext slots is different"); + + // read ciphertexts in chunks and accumulate them + jdoubleArray result = env->NewDoubleArray(num_slots * slot_size); + for (int chunk_idx = 0; chunk_idx < num_slots; chunk_idx++) { + Ciphertext encrypted_sum = deserialize_ciphertext(*encrypted_sum_stream, server->context()); + + vector partial_decryptions; + partial_decryptions.reserve(num_plaintext_arys); + for (int i = 0; i < num_plaintext_arys; i++) { + Plaintext deserialized; + deserialized.load(server->context(), *buf[i]); + partial_decryptions.emplace_back(deserialized); + } + vector<double> averages = server->average(encrypted_sum, move(partial_decryptions)); + env->SetDoubleArrayRegion(result, chunk_idx*slot_size, averages.size(), averages.data()); + } + + return result; +} \ No newline at end of file diff --git a/src/main/cpp/he/he.h b/src/main/cpp/he/he.h new file mode 100644 index 00000000000..c7b0ad05d5f --- /dev/null +++ b/src/main/cpp/he/he.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <jni.h> +/* Header for class org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper */ + +#ifndef _Included_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper +#define _Included_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: initClient + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initClient + (JNIEnv *, jclass, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: generatePartialPublicKey + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generatePartialPublicKey + (JNIEnv *, jclass, jlong); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: setPublicKey + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_setPublicKey + (JNIEnv *, jclass, jlong, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: encrypt + * Signature: (J[D)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_encrypt + (JNIEnv *, jclass, jlong, jdoubleArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: partiallyDecrypt + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_partiallyDecrypt + (JNIEnv *, jclass, jlong, jbyteArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: initServer + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_initServer + (JNIEnv *, jclass); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: generateA + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_generateA + (JNIEnv *, jclass, jlong); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: aggregatePartialPublicKeys + * Signature: (J[[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_aggregatePartialPublicKeys + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: accumulateCiphertexts + * Signature: (J[[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_accumulateCiphertexts + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: org_apache_sysds_utils_NativeHelper + * Method: average + * Signature: (J[B[[B)[D + */ +JNIEXPORT jdoubleArray JNICALL Java_org_apache_sysds_runtime_controlprogram_paramserv_NativeHEHelper_average + (JNIEnv *, jclass, jlong, jbyteArray, jobjectArray); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/src/main/cpp/he/libhe.cpp b/src/main/cpp/he/libhe.cpp new file mode 100644 index 00000000000..5f8a929972c --- /dev/null +++ b/src/main/cpp/he/libhe.cpp @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cassert> +#include <algorithm> +#include <optional> +#include <gsl/span> + +#include "libhe.h" + +#include "seal/seal.h" +#include "seal/util/common.h" +#include "seal/util/rlwe.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal; + +RawPolynomData::RawPolynomData(const SEALContext& context) { + // Extract encryption parameters + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + auto coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + _size = util::mul_safe(coeff_count, coeff_modulus_size); +}; + +void RawPolynomData::set_data(vector<uint64_t >& data) { + assert(data.size() == _size); + _data = move(data); +}; + + +gsl::span<Ciphertext::ct_coeff_type > data_span(Ciphertext& c, size_t n) { + size_t poly_size = util::mul_safe(c.poly_modulus_degree(), c.coeff_modulus_size()); + return { c.data(n), poly_size }; +} + +RawPolynomData generate_a(const SEALContext& context) { + auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create(); + + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + + RawPolynomData rpd(parms); + vector<uint64_t > a_poly_data(rpd.size()); + util::sample_poly_uniform(ciphertext_prng, parms, a_poly_data.data()); + rpd.set_data(a_poly_data); + return rpd; +} + +EncryptionParameters generateParameters() { + EncryptionParameters parms(scheme_type::ckks); + + size_t poly_modulus_degree = 4096; + parms.set_poly_modulus_degree(poly_modulus_degree); + parms.set_coeff_modulus(CoeffModulus::Create(poly_modulus_degree, { 54, 54 })); + return parms; +} + +size_t get_slot_count(const SEALContext& ctx) { + // slot count is only half of it. but every slot can take one complex number or 2 doubles. so in the end we get twice + // the space + return ctx.first_context_data()->parms().poly_modulus_degree(); +} + +// returns a vector filled with random double values between 0 and 1 +vector<double> random_plaintext_data(size_t count) { + // this example is just copied from the CKKS example of SEAL + vector<double> data; + data.reserve(count); + for (size_t i = 0; i < count; i++) + { + data.push_back(sqrt(static_cast<double>(rand()) / RAND_MAX)); + } + return data; +} + +GlobalState::GlobalState(double _scale) : context(generateParameters()), a(generate_a(context)), scale(_scale) {}; + + +PublicKey Client::generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a) +{ + PublicKey public_key; + Ciphertext& destination = public_key.data(); + + // We use a fresh memory pool with `clear_on_destruction' enabled. + MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true); + + auto &context_data = *context.key_context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + auto ntt_tables = context_data.small_ntt_tables(); + size_t encrypted_size = 2; + + // If a polynomial is too small to store UniformRandomGeneratorInfo, + // it is best to just disable save_seed. Note that the size needed is + // the size of UniformRandomGeneratorInfo plus one (uint64_t) because + // of an indicator word that indicates a seeded ciphertext. + size_t poly_uint64_count = util::mul_safe(coeff_count, coeff_modulus_size); + + destination.resize(context, context.key_parms_id(), encrypted_size); + destination.is_ntt_form() = true; + destination.scale() = 1.0; + + // Create an instance of a random number generator. We use this for sampling + // a seed for a second PRNG used for sampling u (the seed can be public + // information. This PRNG is also used for sampling the noise/error below. + auto bootstrap_prng = parms.random_generator()->create(); + + // Sample a public seed for generating uniform randomness + prng_seed_type public_prng_seed; + bootstrap_prng->generate(prng_seed_byte_count, reinterpret_cast<seal_byte *>(public_prng_seed.data())); + + // Set up a new default PRNG for expanding u from the seed sampled above + auto ciphertext_prng = UniformRandomGeneratorFactory::DefaultFactory()->create(public_prng_seed); + + // Generate ciphertext: (c[0], c[1]) = ([-(as+e)]_q, a) + uint64_t *c0 = destination.data(); + uint64_t *c1 = destination.data(1); + + // copy a into c1 + assert(a.size() == poly_uint64_count); + copy(a.data(), a.data()+poly_uint64_count, c1); + + // Sample e <-- chi + auto noise(util::allocate_poly(coeff_count, coeff_modulus_size, pool)); + util::SEAL_NOISE_SAMPLER(bootstrap_prng, parms, noise.get()); + + // Calculate -(a*s + e) (mod q) and store in c[0] + for (size_t i = 0; i < coeff_modulus_size; i++) + { + util::dyadic_product_coeffmod( + secret_key.data().data() + i * coeff_count, c1 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + + // Transform the noise e into NTT representation + ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]); + + util::add_poly_coeffmod( + noise.get() + i * coeff_count, c0 + i * coeff_count, coeff_count, coeff_modulus[i], + c0 + i * coeff_count); + util::negate_poly_coeffmod(c0 + i * coeff_count, coeff_count, coeff_modulus[i], c0 + i * coeff_count); + } + + public_key.parms_id() = context.key_parms_id(); + return public_key; +} + +Client::Client(GlobalState global_state) : _gs(move(global_state)), _encoder(_gs.context) { + KeyGenerator keygen(_gs.context); + _partial_secret_key = keygen.secret_key(); + _partial_public_key = generate_partial_public_key(_partial_secret_key, _gs.context, _gs.a); +}; + +Ciphertext Client::encrypted_data(gsl::span<const double> plain_data) { + if (!_encryptor) { + _encryptor = make_unique<Encryptor>(_gs.context, *_public_key); + } + + // reinterpret plain data as complex<double> + assert(plain_data.size() % 2 == 0); + gsl::span complex_plain_data(reinterpret_cast<const complex<double>*>(plain_data.data()), plain_data.size() / 2); + + Plaintext plaintext; + encoder().encode(complex_plain_data, _gs.scale, plaintext); + Ciphertext ciphertext; + encryptor().encrypt(plaintext, ciphertext); + return ciphertext; +} + +Plaintext Client::partial_decryption(const Ciphertext& encrypted) { + using namespace seal::util; + + // c = (c0, c1) + // dec(c) = c0+c1*s + // we need: c0 + c1*sum(s[i]) + // so we return c1*s[i]*e[i] and add c0 at the server. e[i] is a noise term necessary for security + + // adapted from Decryptor::decrypt + + auto &context_data = *_gs.context.get_context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_modulus_size); + + Plaintext plaintext; + // Since we overwrite destination, we zeroize destination parameters + // This is necessary, otherwise resize will throw an exception. + plaintext.parms_id() = parms_id_zero; + + // Resize destination to appropriate size + plaintext.resize(rns_poly_uint64_count); + + // Do the dot product of encrypted and the secret key array using NTT. + RNSIter destination(plaintext.data(), coeff_count); + ConstRNSIter secret_key_array(_partial_secret_key.data().data(), coeff_count); + ConstRNSIter c1(encrypted.data(1), coeff_count); + + SEAL_ITERATE( + iter(c1, secret_key_array, coeff_modulus, destination), coeff_modulus_size, [&](auto I) { + // put < c_1 * s > mod q in destination + dyadic_product_coeffmod(get<0>(I), get<1>(I), coeff_count, get<2>(I), get<3>(I)); + }); + + // for security we need to introduce noise here + // this part is based on rlwe.cpp:encrypt_zero_symmetric() + auto prng = parms.random_generator()->create(); + MemoryPoolHandle pool = MemoryManager::GetPool(mm_prof_opt::mm_force_new, true); + auto noise(allocate_poly(coeff_count, coeff_modulus_size, pool)); + SEAL_NOISE_SAMPLER(prng, parms, noise.get()); + auto ntt_tables = context_data.small_ntt_tables(); + + for (size_t i = 0; i < coeff_modulus_size; i++) + { + // Transform the noise e into NTT representation + ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]); + + add_poly_coeffmod( + noise.get() + i * coeff_count, plaintext.data() + i * coeff_count, coeff_count, coeff_modulus[i], + plaintext.data() + i * coeff_count); + } + + // Set destination parameters as in encrypted + plaintext.parms_id() = encrypted.parms_id(); + plaintext.scale() = encrypted.scale(); + return plaintext; +} + +Server::Server(GlobalState global_state) : _gs(move(global_state)) {}; + +void Server::accumulate_partial_public_keys(gsl::span<const Ciphertext> partial_pub_keys) { + // sum only the first poly of the ciphertexts + // the second poly is always the same, see GlobalState.a + Ciphertext sum = sum_first_polys(context(), partial_pub_keys); + _public_key.data() = sum; + assert(is_valid_for(_public_key, context())); +} + +Ciphertext Server::sum_data(vector<Ciphertext>&& data) const { + Evaluator e(_gs.context); + Ciphertext result; + e.add_many(data, result); + return result; +} + +vector<double> Server::average(const Ciphertext& encrypted_sum, gsl::span<const Plaintext> partial_decryptions) const { + // the partial decryptions were of the form c1*s[i]. we need c0 + sum(c1+s[i]) + // so we need to add c0 once here. + + // FIXME: this copies encrypted_sum, which is unnecessary + uint64_t num_coeffs = util::mul_safe(encrypted_sum.poly_modulus_degree(), encrypted_sum.coeff_modulus_size()); + gsl::span<const Plaintext::pt_coeff_type> es_data(encrypted_sum.data(0), num_coeffs); + Plaintext c0(es_data); + c0.parms_id() = context().first_parms_id(); + c0.scale() = encrypted_sum.scale(); + + sum_first_polys_inplace(_gs.context, c0, partial_decryptions); // c0 + sum(c1+s[i]) + + // decode sum + size_t slot_count = context().first_context_data()->parms().poly_modulus_degree() >> 1; + CKKSEncoder encoder(context()); + vector<double> result(slot_count * 2, 0.0); + gsl::span<complex<double>> result_destination(reinterpret_cast<complex<double>*>(result.data()), slot_count); + encoder.decode(c0, result_destination); + + // divide by N for average + for (double& x : result) { + x /= static_cast<double>(partial_decryptions.size()); + } + return result; +} + diff --git a/src/main/cpp/he/libhe.h b/src/main/cpp/he/libhe.h new file mode 100644 index 00000000000..25774a80a43 --- /dev/null +++ b/src/main/cpp/he/libhe.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef LIBHE_H +#define LIBHE_H + +#include <cassert> +#include <algorithm> +#include <optional> +#include <gsl/span> + +#include "seal/seal.h" +#include "seal/util/common.h" +#include "seal/util/rlwe.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal; + +class RawPolynomData { + vector<uint64_t > _data; + size_t _size; + +public: + explicit RawPolynomData(const SEALContext& context); + + SEAL_NODISCARD inline const size_t& size() const { return _size; }; + SEAL_NODISCARD inline uint64_t* data() { return _data.data(); }; + SEAL_NODISCARD inline gsl::span<uint64_t > data_span() { return { data(), size() }; }; + + void set_data(vector<uint64_t >& data); +}; + +gsl::span<Ciphertext::ct_coeff_type > data_span(Ciphertext& c, size_t n); + +RawPolynomData generate_a(const SEALContext& context); + +EncryptionParameters generateParameters(); + +size_t get_slot_count(const SEALContext& ctx); + +// returns a vector filled with random double values between 0 and 1 +vector<double> random_plaintext_data(size_t count); + +struct GlobalState { + SEALContext context; + RawPolynomData a; + double scale; + + explicit GlobalState(double _scale); +}; + +class Client { + GlobalState _gs; + CKKSEncoder _encoder; + SecretKey _partial_secret_key; + PublicKey _partial_public_key; + std::optional<PublicKey> _public_key = std::nullopt; + std::unique_ptr<Encryptor> _encryptor = nullptr; + + SEAL_NODISCARD static PublicKey generate_partial_public_key(const SecretKey &secret_key, const SEALContext &context, RawPolynomData& a); + +public: + explicit Client(GlobalState global_state); + + SEAL_NODISCARD inline const SEALContext& context() const { return _gs.context; }; + SEAL_NODISCARD inline const PublicKey& partial_public_key() const { return _partial_public_key; }; + SEAL_NODISCARD inline const CKKSEncoder& encoder() const { return _encoder; }; + SEAL_NODISCARD inline CKKSEncoder& encoder() { return _encoder; }; + SEAL_NODISCARD inline const Encryptor& encryptor() const { assert(_encryptor != nullptr); return *_encryptor; }; + SEAL_NODISCARD inline const PublicKey& public_key() { return *_public_key; }; + inline void set_public_key(const PublicKey& pk) { _public_key = make_optional(pk); }; + + Ciphertext encrypted_data(gsl::span<const double> plain_data); + + Plaintext partial_decryption(const Ciphertext& encrypted); +}; + +// adds b to a in place +template<typename T> void sum_first_poly_inplace(const SEALContext& context, T& a, const T& b) { + auto &context_data = *context.get_context_data(a.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_modulus_size = coeff_modulus.size(); + + // by dereferencing we get only the first poly + auto summand_iter = *util::ConstPolyIter(b.data(), coeff_count, coeff_modulus_size); + auto sum_iter = *util::ConstPolyIter(a.data(), coeff_count, coeff_modulus_size); + auto result_iter = *util::PolyIter(a.data(), coeff_count, coeff_modulus_size); + // see Evaluator::add_inplace + util::add_poly_coeffmod(sum_iter, summand_iter, coeff_modulus_size, coeff_modulus, result_iter); +} + +// This function adds the first polys in summands to sum (either Ciphertext or Plaintext). +template<typename T> T sum_first_polys_inplace(const SEALContext& context, T& sum, gsl::span<const T> summands) { + for (size_t i = 0; i < summands.size(); i++) { + sum_first_poly_inplace(context, sum, summands[i]); + } + return sum; +} + +// This function sums the first polys in summands (either Ciphertext or Plaintext). +template<typename T> T sum_first_polys(const SEALContext& context, gsl::span<const T> summands) { + T sum = summands[0]; + sum_first_polys_inplace(context, sum, gsl::span(&summands.data()[1], summands.size() - 1)); + return sum; +} + +class Server { + GlobalState _gs; + PublicKey _public_key; + +public: + explicit Server(GlobalState global_state); + + SEAL_NODISCARD inline RawPolynomData& a() { return _gs.a; }; + SEAL_NODISCARD inline const SEALContext& context() const { return _gs.context; }; + SEAL_NODISCARD inline const PublicKey& public_key() const { return _public_key; }; + + void accumulate_partial_public_keys(gsl::span<const Ciphertext> partial_pub_keys); + + Ciphertext sum_data(vector<Ciphertext>&& data) const; + + vector<double> average(const Ciphertext& encrypted_sum, gsl::span<const Plaintext> partial_decryptions) const; +}; + +#endif //LIBHE_H diff --git a/src/main/cpp/lib/libhe-Linux-x86_64.so b/src/main/cpp/lib/libhe-Linux-x86_64.so new file mode 100644 index 00000000000..5d55922788b Binary files /dev/null and b/src/main/cpp/lib/libhe-Linux-x86_64.so differ diff --git a/src/main/cpp/systemds.cpp b/src/main/cpp/systemds.cpp index bed1d42f578..86ac053ad10 100644 --- a/src/main/cpp/systemds.cpp +++ b/src/main/cpp/systemds.cpp @@ -17,6 +17,12 @@ * under the License. */ +#ifdef _WIN32 +#include <winsock.h> +#else +#include <arpa/inet.h> +#endif + #include "common.h" #include "libmatrixdnn.h" #include "libmatrixmult.h" @@ -248,4 +254,4 @@ JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_conv2dBackwardF RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads); RELEASE_ARRAY(env, ret, retPtr, numThreads); return static_cast<jlong>(nnz); -} +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index b6ef98abc72..3dfad3413e5 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -44,7 +44,7 @@ public enum ExecType { CP, CP_FILE, SPARK, GPU, FED, INVALID } * Data types (tensor, matrix, scalar, frame, object, unknown). */ public enum DataType { - TENSOR, MATRIX, SCALAR, FRAME, LIST, UNKNOWN; + TENSOR, MATRIX, SCALAR, FRAME, LIST, ENCRYPTED_CIPHER, ENCRYPTED_PLAIN, UNKNOWN; public boolean isMatrix() { return this == MATRIX; diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index a6bcc2dac44..56275fd3b93 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -320,7 +320,8 @@ private void validateParamserv(DataIdentifier output, boolean conditional) { Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_VAL_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING, - Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES, Statement.PS_MODELAVG); + Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES, + Statement.PS_MODELAVG, Statement.PS_HE); checkInvalidParameters(getOpCode(), getVarParams(), valid); // check existence and correctness of parameters diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 995a1e23309..d22a5401806 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -76,6 +76,7 @@ public abstract class Statement implements ParseInfo public static final String PS_SEED = "seed"; public static final String PS_MODELAVG = "modelAvg"; public static final String PS_NBATCHES = "nbatches"; + public static final String PS_HE = "he"; public enum PSModeType { FEDERATED, LOCAL, REMOTE_SPARK } @@ -124,7 +125,6 @@ public enum PSCheckpointing { public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname"; public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid"; - public abstract boolean controlStatement(); public abstract VariableSet variablesRead(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index e398abd32b0..3a1c78cb1d1 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -38,6 +38,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -84,7 +85,9 @@ public class ExecutionContext { //parfor temporary functions (created by eval) protected Set<String> _fnNames; - + + protected SEALClient _seal_client; + /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */ @@ -152,6 +155,14 @@ public long getTID() { return _tid; } + public void setSealClient(SEALClient seal_client) { + _seal_client = seal_client; + } + + public SEALClient getSealClient() { + return _seal_client; + } + /** * Get the i-th GPUContext * @param index index of the GPUContext diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java index 74e113ba020..23665c1dc02 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java @@ -29,6 +29,8 @@ import javax.net.ssl.SSLException; +import io.netty.handler.codec.serialization.ClassResolvers; +import io.netty.handler.codec.serialization.ObjectDecoder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; @@ -46,6 +48,7 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; +import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -171,7 +174,26 @@ public synchronized static Future<FederatedResponse> executeFederatedOperation(I final DataRequestHandler handler = new DataRequestHandler(); // Client Netty - b.handler(createChannel(address, handler)); + b.group(workerGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline cp = ch.pipeline(); + if(ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION)) { + cp.addLast(SslConstructor().context + .newHandler(ch.alloc(), address.getAddress().getHostAddress(), address.getPort())); + } + final int timeout = ConfigurationManager.getFederatedTimeout(); + if(timeout > -1) + cp.addLast("timeout",new ReadTimeoutHandler(timeout)); + + cp.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logServerTraffic)); + cp.addLast("ObjectDecoder", + new ObjectDecoder(Integer.MAX_VALUE, + ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()))); + cp.addLast("FederatedOperationHandler", handler); + cp.addLast("ObjectEncoder", new ObjectEncoder()); + } + }); ChannelFuture f = b.connect(address).sync(); Promise<FederatedResponse> promise = f.channel().eventLoop().newPromise(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java index de56a1a52e0..77ffb7f847f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java @@ -25,6 +25,7 @@ import org.apache.log4j.Logger; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; public class FederatedLocalData extends FederatedData { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java index d95b02afd21..56636e1f7d3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java @@ -73,6 +73,8 @@ public class FederatedStatistics { private static final LongAdder transferredMatrixBytes = new LongAdder(); private static final LongAdder transferredFrameBytes = new LongAdder(); private static final LongAdder asyncPrefetchCount = new LongAdder(); + private static final LongAdder bytesSent = new LongAdder(); + private static final LongAdder bytesReceived = new LongAdder(); // stats on the federated worker itself private static final LongAdder fedLookupTableGetCount = new LongAdder(); @@ -84,6 +86,18 @@ public class FederatedStatistics { private static final LongAdder fedPutLineageItems = new LongAdder(); private static final LongAdder fedSerializationReuseCount = new LongAdder(); private static final LongAdder fedSerializationReuseBytes = new LongAdder(); + private static final LongAdder fedBytesSent = new LongAdder(); + private static final LongAdder fedBytesReceived = new LongAdder(); + + public static void logServerTraffic(long read, long written) { + bytesReceived.add(read); + bytesSent.add(written); + } + + public static void logWorkerTraffic(long read, long written) { + fedBytesReceived.add(read); + fedBytesSent.add(written); + } public static synchronized void incFederated(RequestType rqt, List<Object> data){ switch (rqt) { @@ -164,6 +178,10 @@ public static void reset() { fedPutLineageItems.reset(); fedSerializationReuseCount.reset(); fedSerializationReuseBytes.reset(); + bytesSent.reset(); + bytesReceived.reset(); + fedBytesSent.reset(); + fedBytesReceived.reset(); } public static String displayFedIOExecStatistics() { @@ -194,6 +212,19 @@ public static String displayFedIOExecStatistics() { return ""; } + public static String displayNetworkTrafficStatistics() { + return "Server I/O bytes (read/written):\t" + + bytesReceived.longValue() + + "/" + + bytesSent.longValue() + + "\n" + + "Worker I/O bytes (read/written):\t" + + fedBytesReceived.longValue() + + "/" + + fedBytesSent.longValue() + + "\n"; + } + public static void registerFedWorker(String host, int port) { _fedWorkerAddresses.add(new ImmutablePair<>(host, Integer.valueOf(port))); @@ -232,6 +263,7 @@ public static String displayStatistics(FedStatsCollection fedStats, int numHeavy sb.append(displayLinCacheStats(fedStats.linCacheStats)); sb.append(displayMultiTenantStats(fedStats.mtStats)); sb.append(displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters)); + sb.append(displayNetworkTrafficStatistics()); return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java index a41f656524d..fc93bb603ca 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java @@ -27,6 +27,8 @@ import javax.net.ssl.SSLException; +import io.netty.handler.codec.serialization.ClassResolvers; +import io.netty.handler.codec.serialization.ObjectDecoder; import org.apache.log4j.Logger; import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; @@ -53,6 +55,8 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; +import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; public class FederatedWorker { protected static Logger log = Logger.getLogger(FederatedWorker.class); @@ -62,6 +66,7 @@ public class FederatedWorker { private final FederatedReadCache _frc; private final FederatedWorkloadAnalyzer _fan; private final boolean _debug; + private Timing networkTimer = new Timing(); public FederatedWorker(int port, boolean debug) { _flt = new FederatedLookupTable(); @@ -90,15 +95,30 @@ private void run() { new SynchronousQueue<Runnable>(true)); NioEventLoopGroup workerGroup = new NioEventLoopGroup(EVENT_LOOP_THREADS, workerTPE); - final boolean ssl = ConfigurationManager.isFederatedSSL(); try { + SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext cont2 = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()).build(); final ServerBootstrap b = new ServerBootstrap(); - b.group(bossGroup, workerGroup); - b.channel(NioServerSocketChannel.class); - b.childHandler(createChannel(ssl)); - b.option(ChannelOption.SO_BACKLOG, 128); - b.childOption(ChannelOption.SO_KEEPALIVE, true); - + b.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer<SocketChannel>() { + @Override + public void initChannel(SocketChannel ch) { + ChannelPipeline cp = ch.pipeline(); + if(ConfigurationManager.getDMLConfig() + .getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION)) { + cp.addLast(cont2.newHandler(ch.alloc())); + } + cp.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logWorkerTraffic)); + cp.addLast("ObjectDecoder", + new ObjectDecoder(Integer.MAX_VALUE, + ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()))); + cp.addLast("ObjectEncoder", new ObjectEncoder()); + cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt, _frc, networkTimer)); + } + }) + .option(ChannelOption.SO_BACKLOG, 128) + .childOption(ChannelOption.SO_KEEPALIVE, true); log.info("Starting Federated Worker server at port: " + _port); ChannelFuture f = b.bind(_port).sync(); log.info("Started Federated Worker at port: " + _port); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index 4c90c74b1b9..2f2b206311d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.Instruction.IType; import org.apache.sysds.runtime.instructions.InstructionParser; @@ -73,6 +74,7 @@ import org.apache.sysds.runtime.privacy.DMLPrivacyException; import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.utils.Statistics; +import org.apache.sysds.utils.stats.ParamServStatistics; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -93,7 +95,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private final FederatedReadCache _frc; /** Federated workload analyzer */ - private final FederatedWorkloadAnalyzer _fan; + private FederatedWorkloadAnalyzer _fan; + + private Timing _timing = null; /** * Create a Federated Worker Handler. @@ -111,6 +115,16 @@ public FederatedWorkerHandler(FederatedLookupTable flt, FederatedReadCache frc, _fan = fan; } + public FederatedWorkerHandler(FederatedLookupTable flt, FederatedReadCache frc, Timing timing) { + this(flt, frc); + _timing = timing; + } + + public FederatedWorkerHandler(FederatedLookupTable flt, FederatedReadCache frc) { + _flt = flt; + _frc = frc; + } + @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { ctx.writeAndFlush(createResponse(msg, ctx.channel().remoteAddress())) @@ -122,6 +136,14 @@ protected FederatedResponse createResponse(Object msg) { } private FederatedResponse createResponse(Object msg, SocketAddress remoteAddress) { + try { + if (_timing != null) { + ParamServStatistics.accFedNetworkTime((long) _timing.stop()); + } + } catch (RuntimeException ignored) { + // ignore timing if it wasn't started yet + } + String host; if(remoteAddress instanceof InetSocketAddress) { host = ((InetSocketAddress) remoteAddress).getHostString(); @@ -135,7 +157,11 @@ else if(remoteAddress instanceof SocketAddress) { host = FederatedLookupTable.NOHOST; } - return createResponse(msg, host); + FederatedResponse res = createResponse(msg, host); + if (_timing != null) { + _timing.start(); + } + return res; } private FederatedResponse createResponse(Object msg, String remoteHost) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 004e35b5718..54d778486a7 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -34,24 +34,15 @@ import org.apache.sysds.runtime.controlprogram.ProgramBlock; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.federated.FederatedData; -import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.*; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; -import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; -import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; -import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.cp.BooleanObject; -import org.apache.sysds.runtime.instructions.cp.CPOperand; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.instructions.cp.DoubleObject; -import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; -import org.apache.sysds.runtime.instructions.cp.IntObject; -import org.apache.sysds.runtime.instructions.cp.ListObject; -import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.instructions.cp.*; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.lineage.LineageItem; @@ -59,11 +50,13 @@ import org.apache.sysds.utils.stats.ParamServStatistics; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.concurrent.Callable; import java.util.concurrent.Future; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.sysds.runtime.util.ProgramConverter.*; @@ -83,20 +76,23 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void> private final boolean _weighting; private double _weightingFactor = 1; private boolean _cycleStartAt0 = false; + private boolean _use_homomorphic_encryption = false; + private PublicKey _partial_public_key; public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize, - int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg) + int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption) { super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg); _numBatchesPerEpoch = numBatchesPerGlobalEpoch; _runtimeBalancing = runtimeBalancing; - _weighting = weighting; + _weighting = weighting && (!use_homomorphic_encryption); // FIXME: this disables weighting in favor of homomorphic encryption _numBatchesPerNbatch = nbatches; // generate the ID for the model _modelVarID = FederationUtils.getNextFedDataID(); - _modelAvg = modelAvg; + _modelAvg = _use_homomorphic_encryption || modelAvg; // we always have to use modelAvg when using homomorphic encryption + _use_homomorphic_encryption = use_homomorphic_encryption; } /** @@ -106,6 +102,9 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque */ public void setup(double weightingFactor) { incWorkerNumber(); + if (_use_homomorphic_encryption) { + ((HEParamServer)_ps).registerThread(_workerID, this); + } // prepare features and labels _featuresData = _features.getFedMapping().getFederatedData()[0]; @@ -160,21 +159,43 @@ public void setup(double weightingFactor) { PROG_END); // write program and meta data to worker - Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), - new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, - programSerialized, _inst.getNamespace(), _inst.getFunctionName(), - _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), - _modelVarID, _nbatches, _modelAvg))); + Future<FederatedResponse> udfResponse; + + final SetupFederatedWorker udf; + if (_use_homomorphic_encryption) { + byte[] a = ((HEParamServer)_ps).generateA(); + // generate pk[i] on each client and return it + udf = new SetupHEFederatedWorker(a); + } else { + udf = new SetupFederatedWorker(); + } + udf.setParams(_batchSize, dataSize, _possibleBatchesPerLocalEpoch, + programSerialized, _inst.getNamespace(), _inst.getFunctionName(), + _ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"), + _modelVarID, _nbatches, _use_homomorphic_encryption || _modelAvg); + + udfResponse = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf)); + + FederatedResponse response; try { - FederatedResponse response = udfResponse.get(); + response = udfResponse.get(); if(!response.isSuccessful()) throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed"); + } catch(Exception e) { throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF" + e.getMessage()); } + if (_use_homomorphic_encryption) { + try { + _partial_public_key = (PublicKey) response.getData()[0]; + } + catch (Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: HE Setup UDF didn't return an object"); + } + } } /** @@ -196,29 +217,33 @@ public void teardown() { throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage()); } } - + /** * Setup UDF executed on the federated worker */ private static class SetupFederatedWorker extends FederatedUDF { private static final long serialVersionUID = -3148991224792675607L; - private final long _batchSize; - private final long _dataSize; - private final int _possibleBatchesPerLocalEpoch; - private final String _programString; - private final String _namespace; - private final String _gradientsFunctionName; - private final String _aggregationFunctionName; - private final ListObject _hyperParams; - private final long _modelVarID; - private final boolean _modelAvg; - private final int _nbatches; - - protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, - String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, - ListObject hyperParams, long modelVarID, int nbatches, boolean modelAvg) + private long _batchSize; + private long _dataSize; + private int _possibleBatchesPerLocalEpoch; + private String _programString; + private String _namespace; + private String _gradientsFunctionName; + private String _aggregationFunctionName; + private ListObject _hyperParams; + private long _modelVarID; + private boolean _modelAvg; + private int _nbatches; + private boolean _params_set = false; + + protected SetupFederatedWorker() { super(new long[]{}); + } + + public void setParams(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch, + String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, + ListObject hyperParams, long modelVarID, int nbatches, boolean modelAvg) { _batchSize = batchSize; _dataSize = dataSize; _possibleBatchesPerLocalEpoch = possibleBatchesPerLocalEpoch; @@ -230,10 +255,15 @@ protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatche _modelVarID = modelVarID; _modelAvg = modelAvg; _nbatches = nbatches; + _params_set = true; } @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { + if (!_params_set) { + return new FederatedResponse(FederatedResponse.ResponseType.ERROR, "params were not set"); + } + // parse and set program ec.setProgram(ProgramConverter.parseProgram(_programString, 0)); @@ -258,9 +288,59 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { } } - /** + private static class SetupHEFederatedWorker extends SetupFederatedWorker { + private static final long serialVersionUID = 9128347291804980123L; + + byte[] _partial_pubkey_a; + + protected SetupHEFederatedWorker(byte[] partial_pubkey_a) { + // delegate everything to parent class. set modelAvg to true, as it is the only supported case + super(); + _partial_pubkey_a = partial_pubkey_a; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + // TODO: set other CKKS parameters + // TODO generate partial public key + NativeHEHelper.initialize(); + + SEALClient sc = new SEALClient(_partial_pubkey_a); + ec.setSealClient(sc); + PublicKey partial_pubkey = sc.generatePartialPublicKey(); + + FederatedResponse res = super.execute(ec, data); + if (!res.isSuccessful()) { + return res; + } + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, partial_pubkey); + } + } + /** * Teardown UDF executed on the federated worker */ + private static class SetPublicKeyFederatedWorker extends FederatedUDF { + private static final long serialVersionUID = -1536502123123318969L; + private final PublicKey _public_key; + + protected SetPublicKeyFederatedWorker(PublicKey public_key) { + super(new long[]{}); + _public_key = public_key; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + ec.getSealClient().setPublicKey(_public_key); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); + } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + return null; + } + } + private static class TeardownFederatedWorker extends FederatedUDF { private static final long serialVersionUID = -153650281873318969L; @@ -298,6 +378,7 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { @Override public Void call() throws Exception { try { + Timing tTotal = new Timing(true); switch (_freq) { case BATCH: computeWithBatchUpdates(); @@ -324,6 +405,7 @@ protected ListObject pullModel() { } protected void weightAndPushGradients(ListObject gradients) { + assert (!(_weighting && _use_homomorphic_encryption)) : "weights and homomorphic encryption are not supported together"; // scale gradients - must only include MatrixObjects if(_weighting && _weightingFactor != 1) { Timing tWeighting = DMLScript.STATISTICS ? new Timing(true) : null; @@ -354,11 +436,17 @@ protected void computeWithBatchUpdates() { int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); ListObject model = pullModel(); ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum); - if (_modelAvg) + + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; + if (_modelAvg && !_use_homomorphic_encryption) + // we can't call the agg fn if we use HE, because it is implemented homomorphically in SEALServer::aggregateCiphertexts model = _ps.updateLocalModel(_ec, gradients, model); else ParamservUtils.cleanupListObject(model); - weightAndPushGradients(_modelAvg ? model : gradients); + weightAndPushGradients((_modelAvg && !_use_homomorphic_encryption) ? model : gradients); + if (tAgg != null) { + ParamServStatistics.accFedAggregation((long)tAgg.stop()); + } ParamservUtils.cleanupListObject(gradients); } } @@ -377,7 +465,13 @@ protected void computeWithNBatchUpdates() { currentLocalBatchNumber = currentLocalBatchNumber + _numBatchesPerNbatch; ListObject model = pullModel(); ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerNbatch, localStartBatchNum, true); + + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; weightAndPushGradients(gradients); + if (tAgg != null) { + ParamServStatistics.accFedAggregation((long)tAgg.stop()); + } + ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); } @@ -394,7 +488,13 @@ protected void computeWithEpochUpdates() { // Pull the global parameters from ps ListObject model = pullModel(); ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true); + + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; weightAndPushGradients(gradients); + if (tAgg != null) { + ParamServStatistics.accFedAggregation((long)tAgg.stop()); + } + ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); } @@ -431,11 +531,16 @@ protected ListObject computeGradientsForNBatches(ListObject model, } // create and execute the udf on the remote worker + Object udf; + if (_use_homomorphic_encryption) { + udf = new HEComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID()}, + new long[]{_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum); + } else { + udf = new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), + _modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum); + } Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), - new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), - _modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum) - )); + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf)); try { Object[] responseData = udfResponse.get().getData(); @@ -444,6 +549,7 @@ protected ListObject computeGradientsForNBatches(ListObject model, long workerComputing = ((DoubleObject) responseData[1]).getLongValue(); ParamServStatistics.accFedWorkerComputing(workerComputing); ParamServStatistics.accFedCommunicationTime(total - workerComputing); + ParamServStatistics.accFedNetworkTime(total); } return (ListObject) responseData[0]; } @@ -492,12 +598,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunc, - opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function"); + opt, boundInputs, func.getInputParamNames(), outputNames, "gradient function"); DataIdentifier gradientsOutput = outputs.get(0); // recreate aggregation instruction and output if needed @@ -508,12 +614,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { inputs = func.getInputParams(); outputs = func.getOutputParams(); boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); aggregationInstruction = new FunctionCallCPInstruction(namespace, aggFunc, - opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function"); + opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function"); aggregationOutput = outputs.get(0); } ListObject accGradients = null; @@ -540,7 +646,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // accrue the computed gradients - In the single batch case this is just a list copy // is this equivalent for momentum based and AMS prob? accGradients = modelAvg ? null : - ParamservUtils.accrueGradients(accGradients, gradients, false); + ParamservUtils.accrueGradients(accGradients, gradients, false); // update the local model with gradients if needed // FIXME ensure that with modelAvg we always update the model @@ -564,11 +670,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // model clean up ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString()); // TODO double check cleanup gradients and models - + // stop timing DoubleObject gradientsTime = new DoubleObject(tGradients.stop()); + ParamServStatistics.accGradientComputeTime(gradientsTime.getLongValue()); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, - new Object[]{modelAvg ? model : accGradients, gradientsTime}); + new Object[]{modelAvg ? model : accGradients, gradientsTime}); } @Override @@ -577,6 +684,102 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { } } + + /** + * This wraps federatedComputeGradientsForNBatches and adds encryption + */ + private static class HEComputeGradientsForNBatches extends federatedComputeGradientsForNBatches { + private static final long serialVersionUID = -3535901512348794852L; + private final long[] _deferredIds; + + protected HEComputeGradientsForNBatches(long[] deferredIds, long[] inIDs, int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) { + super(inIDs, numBatchesToCompute, localUpdate, localStartBatchNum); + _deferredIds = deferredIds; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data_without_deferred) { + Timing tTotal = new Timing(true); + // add features and gradients to data + Data[] deferred_inputs = Arrays.stream(_deferredIds).mapToObj(id -> ec.getVariable(String.valueOf(id))).toArray(Data[]::new); + Data[] data = Arrays.copyOf(deferred_inputs, deferred_inputs.length + data_without_deferred.length); + System.arraycopy(data_without_deferred, 0, data, deferred_inputs.length, data_without_deferred.length); + FederatedResponse res = super.execute(ec, data); + + if (!res.isSuccessful()) { + return res; + } + + // encrypt model with SEAL + try { + Timing tEncrypt = DMLScript.STATISTICS ? new Timing(true) : null; + + ListObject model = (ListObject) res.getData()[0]; + ListObject encrypted_model = new ListObject(model); + IntStream.range(0, model.getLength()).forEach(matrix_idx -> { + CiphertextMatrix encrypted_matrix = ec.getSealClient().encrypt((MatrixObject) model.getData(matrix_idx)); + encrypted_model.set(matrix_idx, encrypted_matrix); + }); + + // overwrite model with encryption + res.getData()[0] = encrypted_model; + + if (tEncrypt != null) { + ParamServStatistics.accHEEncryptionTime((long)tEncrypt.stop()); + } + + // stop timing + DoubleObject gradientsTime = new DoubleObject(tTotal.stop()); + res.getData()[1] = gradientsTime; + } catch (Exception e) { + return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new Object[] { e }); + } + return res; + } + } + + private static class HEComputePartialDecryption extends FederatedUDF { + private static final long serialVersionUID = -4535098129348794852L; + private final CiphertextMatrix[] _encrypted_sum; + + protected HEComputePartialDecryption(CiphertextMatrix[] encrypted_sum) { + super(new long[]{}); + _encrypted_sum = encrypted_sum; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + Timing tPartialDecrypt = DMLScript.STATISTICS ? new Timing(true) : null; + PlaintextMatrix[] result = new PlaintextMatrix[_encrypted_sum.length]; + IntStream.range(0, result.length).forEach(i -> { + result[i] = ec.getSealClient().partiallyDecrypt(_encrypted_sum[i]); + }); + if (tPartialDecrypt != null) { + ParamServStatistics.accHEPartialDecryptionTime((long)tPartialDecrypt.stop()); + } + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, result); + } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + return null; + } + } + + + public PlaintextMatrix[] getPartialDecryption(CiphertextMatrix[] encrypted_sum) { + Object udf = new HEComputePartialDecryption(encrypted_sum); + Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), udf)); + + try { + Object[] responseData = udfResponse.get().getData(); + return (PlaintextMatrix[]) responseData; + } catch(Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage()); + } + } + // Statistics methods protected void accFedPSGradientWeightingTime(Timing time) { if (DMLScript.STATISTICS && time != null) @@ -608,4 +811,24 @@ protected void accBatchIndexingTime(Timing time) { protected void accGradientComputeTime(Timing time) { throw new NotImplementedException(); } + + public PublicKey getPartialPublicKey() { + return _partial_public_key; + } + + public void setPublicKey(PublicKey public_key) { + Future<FederatedResponse> res = _featuresData.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), + new SetPublicKeyFederatedWorker(public_key))); + + try { + FederatedResponse response = res.get(); + if(!response.isSuccessful()) + throw new DMLRuntimeException("FederatedLocalPSThread: SetPublicKey UDF failed"); + + } + catch(Exception e) { + throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Public Key Setup UDF" + e.getMessage()); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java new file mode 100644 index 00000000000..577bf6c8205 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.parser.Statement; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.ListObject; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.utils.NativeHelper; +import org.apache.sysds.utils.stats.ParamServStatistics; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * This class implements Homomorphic Encryption (HE) for LocalParamServer. It only supports modelAvg=true. + */ +public class HEParamServer extends LocalParamServer { + private int _thread_counter = 0; + private final List<FederatedPSControlThread> _threads; + private final List<Object> _result_buffer; // one per thread + private Object _result; + private final SEALServer _seal_server; + + public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, + MatrixObject valFeatures, MatrixObject valLabels, int nbatches) + { + NativeHEHelper.initialize(); + return new HEParamServer(model, aggFunc, updateType, freq, ec, + workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches); + } + + private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, + MatrixObject valFeatures, MatrixObject valLabels, int nbatches) + { + super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true); + + _seal_server = new SEALServer(); + + _threads = Collections.synchronizedList(new ArrayList<>(workerNum)); + for (int i = 0; i < getNumWorkers(); i++) { + _threads.add(null); + } + + _result_buffer = new ArrayList<>(workerNum); + resetResultBuffer(); + } + + public void registerThread(int thread_id, FederatedPSControlThread thread) { + _threads.set(thread_id, thread); + } + + private synchronized void resetResultBuffer() { + _result_buffer.clear(); + for (int i = 0; i < getNumWorkers(); i++) { + _result_buffer.add(null); + } + } + + public byte[] generateA() { + return _seal_server.generateA(); + } + + public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) { + return _seal_server.aggregatePartialPublicKeys(partial_public_keys); + } + + /** + * this method collects all T Objects from each worker into a list and then calls f once on this list to produce + * another T, which it returns. + */ + private synchronized <T,U> U collectAndDo(int workerId, T obj, Function<List<T>, U> f) { + _result_buffer.set(workerId, obj); + _thread_counter++; + + if (_thread_counter == getNumWorkers()) { + List<T> buf = _result_buffer.stream().map(x -> (T)x).collect(Collectors.toList()); + _result = f.apply(buf); + resetResultBuffer(); + _thread_counter = 0; + notifyAll(); + } else { + try { + wait(); + } catch (InterruptedException i) { + throw new RuntimeException("thread interrupted"); + } + } + + return (U) _result; + } + + private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encrypted_models) { + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; + CiphertextMatrix[] result = new CiphertextMatrix[encrypted_models.get(0).getLength()]; + IntStream.range(0, encrypted_models.get(0).getLength()).forEach(matrix_idx -> { + CiphertextMatrix[] summands = new CiphertextMatrix[encrypted_models.size()]; + for (int i = 0; i < encrypted_models.size(); i++) { + summands[i] = (CiphertextMatrix) encrypted_models.get(i).getData(matrix_idx); + } + result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);; + }); + if (tAgg != null) { + ParamServStatistics.accHEAccumulation((long)tAgg.stop()); + } + return result; + } + + private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, List<PlaintextMatrix[]> partial_decryptions) { + Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null; + + MatrixObject[] result = new MatrixObject[partial_decryptions.get(0).length]; + + IntStream.range(0, partial_decryptions.get(0).length).forEach(matrix_idx -> { + PlaintextMatrix[] partial_plaintexts = new PlaintextMatrix[partial_decryptions.size()]; + for (int i = 0; i < partial_decryptions.size(); i++) { + partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx]; + } + + result[matrix_idx] = _seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts); + }); + + ListObject old_model = getResult(); + ListObject new_model = new ListObject(old_model); + for (int i = 0; i < new_model.getLength(); i++) { + new_model.set(i, result[i]); + } + + if (tDecrypt != null) { + ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop()); + } + + updateAndBroadcastModel(new_model, null); + return null; + } + + // this is only to be used in push() + private Timing commTimer; + private void startCommTimer() { + commTimer = new Timing(true); + } + private long stopCommTimer() { + return (long)commTimer.stop(); + } + // --------------------------------- + + @Override + public void push(int workerID, ListObject encrypted_model) { + // wait for all updates and sum them homomorphically + CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, encrypted_model, x -> { + CiphertextMatrix[] res = this.homomorphicAggregation(x); + this.startCommTimer(); + return res; + }); + + // get partial decryptions + PlaintextMatrix[] partial_decryption = _threads.get(workerID).getPartialDecryption(homomorphic_sum); + + // do average and update global model + collectAndDo(workerID, partial_decryption, x -> { + ParamServStatistics.accFedNetworkTime(this.stopCommTimer()); + return this.homomorphicAverage(homomorphic_sum, x); + }); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java index 50c76a0f427..9fd49ca0d10 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java @@ -39,7 +39,7 @@ public static LocalParamServer create(ListObject model, String aggFunc, Statemen workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); } - private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, + protected LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java new file mode 100644 index 00000000000..7757ad722bb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java @@ -0,0 +1,100 @@ +package org.apache.sysds.runtime.controlprogram.paramserv; + +import org.apache.commons.lang.SystemUtils; +import org.apache.sysds.utils.NativeHelper; + +public class NativeHEHelper { + public static boolean initialize() { + String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so"); + String library_name = "libhe" + platform_suffix; + return NativeHelper.loadLibraryHelperFromResource(library_name); + } + + // ---------------------------------------------------------------------------------------------------------------- + // SEAL integration + // ---------------------------------------------------------------------------------------------------------------- + + // these are called by SEALClient + + /** + * generates a Client object + * @param a a constant generated by generateA + * @return a pointer to the native client object + */ + public static native long initClient(byte[] a); + + /** + * generates a partial public key + * stores a partial private key corresponding to the partial public key in client + * @param client A pointer to a Client, obtained from initClient + * @return a serialized partial public key + */ + public static native byte[] generatePartialPublicKey(long client); + + /** + * sets the public key and stores it in client + * @param client A pointer to a Client, obtained from initClient + * @param public_key serialized public key obtained from generatePartialPublicKey + */ + public static native void setPublicKey(long client, byte[] public_key); + + /** + * encrypts data with public key stored in client + * setPublicKey() must have been called before calling this + * @param client A pointer to a Client, obtained from initClient + * @param plaintexts array of double values to be encrypted + * @return serialized ciphertext + */ + public static native byte[] encrypt(long client, double[] plaintexts); + + /** + * partially decrypts ciphertexts with the partial private key. generatePartialPublicKey() must + * have been called before calling this function + * @param client A pointer to a Client, obtained from initClient + * @param ciphertext serialized ciphertext + * @return serialized partial decryption + */ + public static native byte[] partiallyDecrypt(long client, byte[] ciphertext); + + // ---------------------------------------------------------------------------------------------------------------- + + // these are called by SEALServer + + /** + * generates the Server object and returns a pointer to it + * @return pointer to a native Server object + */ + public static native long initServer(); + + /** + * this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse + * @param server A pointer to a Server, obtained from initServer + * @return serialized a constant + */ + public static native byte[] generateA(long server); + + /** + * accumulates the given partial public keys into a public key, stores it in server and returns it + * @param server A pointer to a Server, obtained from initServer + * @param partial_public_keys array of serialized partial public keys + * @return serialized partial public key + */ + public static native byte[] aggregatePartialPublicKeys(long server, byte[][] partial_public_keys); + + /** + * accumulates the given ciphertexts into a sum ciphertext and returns it + * @param server A pointer to a Server, obtained from initServer + * @param ciphertexts array of serialized ciphertexts + * @return serialized accumulated ciphertext + */ + public static native byte[] accumulateCiphertexts(long server, byte[][] ciphertexts); + + /** + * averages the partial decryptions and returns the result + * @param server A pointer to a Server, obtained from initServer + * @param encrypted_sum the result of accumulateCiphertexts() + * @param partial_plaintexts the result of partiallyDecrypt of each ciphertext fed into accumulateCiphertexts + * @return average of original data + */ + public static native double[] average(long server, byte[] encrypted_sum, byte[][] partial_plaintexts); +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java new file mode 100644 index 00000000000..d0ba01dba93 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java @@ -0,0 +1,23 @@ +package org.apache.sysds.runtime.controlprogram.paramserv; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.traffic.ChannelTrafficShapingHandler; +import java.util.function.BiConsumer; + +public class NetworkTrafficCounter extends ChannelTrafficShapingHandler { + private final BiConsumer<Long, Long> _fn; // (read, written) -> Void, logs bytes read and written + public NetworkTrafficCounter(BiConsumer<Long, Long> fn) { + // checkInterval of zero means that doAccounting will not be called + super( 0); + _fn = fn; + } + + // log bytes read/written after channel is closed + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + _fn.accept(trafficCounter.cumulativeReadBytes(), trafficCounter.cumulativeWrittenBytes()); + trafficCounter.resetCumulativeTime(); + super.channelInactive(ctx); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java index 009dc20a338..0e09fabf30b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java @@ -19,10 +19,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; +import java.util.*; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.stream.Collectors; @@ -326,30 +323,8 @@ protected synchronized void updateAverageModel(int workerID, ListObject model) { _accModels = ParamservUtils.accrueGradients(_accModels, weightParams, true); if(allFinished()) { - _model = setParams(_ec, _accModels, _model); - if (DMLScript.STATISTICS && tAgg != null) - ParamServStatistics.accAggregationTime((long) tAgg.stop()); - _accModels = null; //reset for next accumulation - - // This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch - // In the BSP batch case that occurs after the sync counter reaches the number of batches and in the - // BSP epoch case every time - if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) { - - if(LOG.isInfoEnabled()) - LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter); - time_epoch(); - if(_validationPossible) { - validate(); - } - _epochCounter++; - _syncCounter = 0; - } - // Broadcast the updated model + updateAndBroadcastModel(_accModels, tAgg); resetFinishedStates(); - broadcastModel(true); - if(LOG.isDebugEnabled()) - LOG.debug("Global parameter is broadcasted successfully "); } break; } @@ -365,7 +340,33 @@ protected synchronized void updateAverageModel(int workerID, ListObject model) { } } - protected ListObject weightModels(ListObject params, int numWorkers) { + protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg) { + _model = setParams(_ec, new_model, _model); + if (DMLScript.STATISTICS && tAgg != null) + ParamServStatistics.accAggregationTime((long) tAgg.stop()); + _accModels = null; //reset for next accumulation + + // This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch + // In the BSP batch case that occurs after the sync counter reaches the number of batches and in the + // BSP epoch case every time + if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) { + + if(LOG.isInfoEnabled()) + LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter); + time_epoch(); + if(_validationPossible) { + validate(); + } + _epochCounter++; + _syncCounter = 0; + } + // Broadcast the updated model + broadcastModel(true); + if(LOG.isDebugEnabled()) + LOG.debug("Global parameter is broadcasted successfully "); + } + + protected ListObject weightModels(ListObject params, int numWorkers) { double _averagingFactor = 1d / numWorkers; if( _averagingFactor != 1) { @@ -472,6 +473,10 @@ private void validate() { ParamServStatistics.accValidationTime((long) tValidate.stop()); } + public int getNumWorkers() { + return _numWorkers; + } + public FunctionCallCPInstruction getAggInst() { return _inst; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index 5bb3e12dcab..96979e3a5d8 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -87,6 +87,7 @@ static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) { new MatrixCharacteristics(range.getSize(0), range.getSize(1)), Types.FileFormat.BINARY) ); + slice.setPrivacyConstraints(fedMatrix.getPrivacyConstraint()); // Create new federation map List<Pair<FederatedRange, FederatedData>> newFedHashMap = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java new file mode 100644 index 00000000000..96fd415308e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/PublicKey.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import java.io.Serializable; + +public class PublicKey implements Serializable { + private static final long serialVersionUID = 91289081237980123L; + + private final byte[] _data; + + public PublicKey(byte[] data) { + _data = data; + } + + public byte[] getData() { + return _data; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java new file mode 100644 index 00000000000..935f2808af5 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.stream.IntStream; + +public class SEALClient { + public SEALClient(byte[] a) { + // TODO take params here, like slot_count etc. + ctx = NativeHEHelper.initClient(a); + } + + // this is a pointer to the context used by all native methods of this class + private final long ctx; + + + /** + * generates a partial public key + * stores a partial private key corresponding to the partial public key in ctx + * + * @return the partial public key + */ + public PublicKey generatePartialPublicKey() { + return new PublicKey(NativeHEHelper.generatePartialPublicKey(ctx)); + } + + /** + * sets the public key and stores it in ctx + * + * @param public_key the public key to set + */ + public void setPublicKey(PublicKey public_key) { + NativeHEHelper.setPublicKey(ctx, public_key.getData()); + } + + /** + * encrypts one block of data with public key stored statically and returns it + * setPublicKey() must have been called before calling this + * @param plaintext the MatrixObject to encrypt + * @return the encrypted matrix + */ + public CiphertextMatrix encrypt(MatrixObject plaintext) { + MatrixBlock mb = plaintext.acquireReadAndRelease(); + if (mb.isInSparseFormat()) { + mb.allocateSparseRowsBlock(); + mb.sparseToDense(); + } + DenseBlock db = mb.getDenseBlock(); + int[] dims = IntStream.range(0, db.numDims()).map(db::getDim).toArray(); + double[] raw_data = mb.getDenseBlockValues(); + return new CiphertextMatrix(dims, plaintext.getDataCharacteristics(), NativeHEHelper.encrypt(ctx, raw_data)); + } + + /** + * partially decrypts ciphertext with the partial private key. generatePartialPublicKey() must + * have been called before calling this function + * + * @param ciphertext the ciphertext to partially decrypt + * @return the partial decryption of ciphertext + */ + public PlaintextMatrix partiallyDecrypt(CiphertextMatrix ciphertext) { + return new PlaintextMatrix(ciphertext.getDims(), ciphertext.getDataCharacteristics(), NativeHEHelper.partiallyDecrypt(ctx, ciphertext.getData())); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java new file mode 100644 index 00000000000..d6265c7f6d7 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALServer.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.DenseBlockFactory; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.Encrypted; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; + +import java.util.Arrays; + +public class SEALServer { + public SEALServer() { + // TODO take params here, like slot_count etc. + ctx = NativeHEHelper.initServer(); + } + + // this is a pointer to the context used by all native methods of this class + private final long ctx; + private byte[] _a; + + /** + * this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse + * @return serialized a constant + */ + public synchronized byte[] generateA() { + if (_a == null) { + _a = NativeHEHelper.generateA(ctx); + } + return _a; + } + + /** + * accumulates the given partial public keys into a public key, stores it in ctx and returns it + * @param partial_public_keys an array of partial public keys generated with SEALServer::generatePartialPublicKey + * @return the aggregated public key + */ + public PublicKey aggregatePartialPublicKeys(PublicKey[] partial_public_keys) { + return new PublicKey(NativeHEHelper.aggregatePartialPublicKeys(ctx, extractRawData(partial_public_keys))); + } + + /** + * accumulates the given ciphertext blocks into a sum ciphertext and returns it + * @param ciphertexts ciphertexts encrypted with the partial public keys + * @return the accumulated ciphertext (which is the homomorphic sum of ciphertexts) + */ + public CiphertextMatrix accumulateCiphertexts(CiphertextMatrix[] ciphertexts) { + return new CiphertextMatrix(ciphertexts[0].getDims(), ciphertexts[0].getDataCharacteristics(), NativeHEHelper.accumulateCiphertexts(ctx, extractRawData(ciphertexts))); + } + + /** + * averages the partial decryptions + * @param encrypted_sum is the result of accumulateCiphertexts() + * @param partial_plaintexts is the result of SEALServer::partiallyDecrypt of each ciphertext fed into accumulateCiphertexts + * @return the unencrypted, element-wise average of the original matrices + */ + public MatrixObject average(CiphertextMatrix encrypted_sum, PlaintextMatrix[] partial_plaintexts) { + double[] raw_result = NativeHEHelper.average(ctx, encrypted_sum.getData(), extractRawData(partial_plaintexts)); + int[] dims = encrypted_sum.getDims(); + int result_len = Arrays.stream(dims).reduce(1, (x,y) -> x*y); + DataCharacteristics dc = encrypted_sum.getDataCharacteristics(); + + DenseBlock new_dense_block = DenseBlockFactory.createDenseBlock(Arrays.copyOf(raw_result, result_len), dims); + MatrixBlock new_matrix_block = new MatrixBlock((int)dc.getRows(), (int)dc.getCols(), new_dense_block); + MatrixObject new_mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(dc, Types.FileFormat.BINARY)); + new_mo.acquireModify(new_matrix_block); + new_mo.release(); + return new_mo; + } + + private static byte[][] extractRawData(Encrypted[] data) { + byte[][] raw_data = new byte[data.length][]; + for (int i = 0; i < data.length; i++) { + raw_data[i] = data[i].getData(); + } + return raw_data; + } + + // TODO: extract an interface for this and use it here + private static byte[][] extractRawData(PublicKey[] data) { + byte[][] raw_data = new byte[data.length][]; + for (int i = 0; i < data.length; i++) { + raw_data[i] = data[i].getData(); + } + return raw_data; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java new file mode 100644 index 00000000000..1cbef9d488e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CiphertextMatrix.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified. + */ +public class CiphertextMatrix extends Encrypted { + private static final long serialVersionUID = 1762936872261940616L; + + public CiphertextMatrix(int[] dims, DataCharacteristics dc, byte[] data) { + super(dims, dc, data, Types.DataType.ENCRYPTED_CIPHER); + } + + @Override + public String getDebugName() { + return "CiphertextMatrix " + getData().hashCode(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java new file mode 100644 index 00000000000..eb7d1ea44a5 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/Encrypted.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted data. It stores the data as opaque byte array. The layout is unspecified. + */ +public abstract class Encrypted extends Data { + private static final long serialVersionUID = 1762936872268046168L; + + private final int[] _dims; + private final DataCharacteristics _dc; + private final byte[] _data; + + public Encrypted(int[] dims, DataCharacteristics dc, byte[] data, Types.DataType dt) { + super(dt, Types.ValueType.UNKNOWN); + _dims = dims; + _dc = dc; + _data = data; + } + + public int[] getDims() { + return _dims; + } + + public DataCharacteristics getDataCharacteristics() { + return _dc; + } + + public byte[] getData() { + return _data; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java index 38288178e4e..5c302fe80a9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java @@ -397,6 +397,18 @@ public void writeExternal(ObjectOutput out) throws IOException { ScalarObject so = (ScalarObject) d; out.writeObject(so.getStringValue()); break; + case ENCRYPTED_CIPHER: + case ENCRYPTED_PLAIN: + Encrypted e = (Encrypted) d; + int[] dims = e.getDims(); + dc = e.getDataCharacteristics(); + out.writeObject(dims); + out.writeObject(dc.getRows()); + out.writeObject(dc.getCols()); + out.writeObject(dc.getBlocksize()); + out.writeObject(dc.getNonZeros()); + out.writeObject(e.getData()); + break; default: throw new DMLRuntimeException("Unable to serialize datatype " + dataType); } @@ -463,6 +475,21 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept } d = so; break; + case ENCRYPTED_CIPHER: + case ENCRYPTED_PLAIN: + int[] dims = (int[]) in.readObject(); + rows = (long) in.readObject(); + cols = (long) in.readObject(); + blockSize = (int) in.readObject(); + nonZeros = (long) in.readObject(); + byte[] data = (byte[])in.readObject(); + DataCharacteristics dc = new MatrixCharacteristics(rows, cols, blockSize, nonZeros); + if (dataType == DataType.ENCRYPTED_CIPHER) { + d = new CiphertextMatrix(dims, dc, data); + } else { + d = new PlaintextMatrix(dims, dc, data); + } + break; default: throw new DMLRuntimeException("Unable to deserialize datatype " + dataType); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 25353f60039..d16aa9ec4e2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -19,28 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN; -import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE; -import static org.apache.sysds.parser.Statement.PS_EPOCHS; -import static org.apache.sysds.parser.Statement.PS_FEATURES; -import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING; -import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING; -import static org.apache.sysds.parser.Statement.PS_FREQUENCY; -import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS; -import static org.apache.sysds.parser.Statement.PS_LABELS; -import static org.apache.sysds.parser.Statement.PS_MODE; -import static org.apache.sysds.parser.Statement.PS_MODEL; -import static org.apache.sysds.parser.Statement.PS_MODELAVG; -import static org.apache.sysds.parser.Statement.PS_NBATCHES; -import static org.apache.sysds.parser.Statement.PS_PARALLELISM; -import static org.apache.sysds.parser.Statement.PS_SCHEME; -import static org.apache.sysds.parser.Statement.PS_SEED; -import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN; -import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE; -import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES; -import static org.apache.sysds.parser.Statement.PS_VAL_FUN; -import static org.apache.sysds.parser.Statement.PS_VAL_LABELS; - import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; @@ -71,25 +49,22 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread; -import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker; -import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer; -import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker; -import org.apache.sysds.runtime.controlprogram.paramserv.SparkParamservUtils; +import org.apache.sysds.runtime.controlprogram.paramserv.*; import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme; import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner; import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.utils.stats.ParamServStatistics; +import static org.apache.sysds.parser.Statement.*; + public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName()); @@ -102,6 +77,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP; public static final int DEFAULT_NBATCHES = 1; private static final Boolean DEFAULT_MODELAVG = false; + private static final Boolean DEFAULT_HE = false; public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) { super(op, paramsMap, out, opcode, istr); @@ -188,23 +164,56 @@ private void runFederated(ExecutionContext ec) { MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) ? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null; MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null; boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG)); - ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(), - getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg); + + // check if we need homomorphic encryption + boolean use_homomorphic_encryption_ = getHe(); + for (int i = 0; i < workerNum; i++) { + use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pFeatures.get(i)); + use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pLabels.get(i)); + } + final boolean use_homomorphic_encryption = use_homomorphic_encryption_; + if (use_homomorphic_encryption && !modelAvg) { + throw new DMLRuntimeException("can't use homomorphic encryption without modelAvg"); + } + + if (use_homomorphic_encryption && weighting) { + throw new DMLRuntimeException("can't use homomorphic encryption with weighting"); + } + + LocalParamServer ps = (LocalParamServer)createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(), + getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg, use_homomorphic_encryption); // Create the local workers int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics); List<FederatedPSControlThread> threads = IntStream.range(0, workerNum) .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting, - getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg)) + getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg, use_homomorphic_encryption)) .collect(Collectors.toList()); if(workerNum != threads.size()) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!"); } + // Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers for (int i = 0; i < threads.size(); i++) { threads.get(i).setFeatures(result._pFeatures.get(i)); threads.get(i).setLabels(result._pLabels.get(i)); threads.get(i).setup(result._weightingFactors.get(i)); } + + if (use_homomorphic_encryption) { + // generate public key from partial public keys + PublicKey[] partial_public_keys = new PublicKey[threads.size()]; + for (int i = 0; i < threads.size(); i++) { + partial_public_keys[i] = threads.get(i).getPartialPublicKey(); + } + + // TODO: accumulate public keys with SEAL + PublicKey public_key = ((HEParamServer)ps).aggregatePartialPublicKeys(partial_public_keys); + + for (FederatedPSControlThread thread : threads) { + thread.setPublicKey(public_key); + } + } + if (DMLScript.STATISTICS) ParamServStatistics.accSetupTime((long) tSetup.stop()); @@ -479,21 +488,32 @@ private int getWorkerNum(PSModeType mode) { * @return parameter server */ private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, - PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) + PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) { return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg); } + + private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, + PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, + int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) { + return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, false); + } + // When this creation is used the parameter server is able to validate after each epoch private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, - int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) + int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption) { switch (mode) { case FEDERATED: case LOCAL: case REMOTE_SPARK: - return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); + if (use_homomorphic_encryption) { + return HEParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches); + } else { + return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg); + } default: throw new DMLRuntimeException("Unsupported parameter server: " + mode.name()); } @@ -614,4 +634,15 @@ private int getNbatches() { } return Integer.parseInt(getParam(PS_NBATCHES)); } + + private boolean checkIsPrivate(MatrixObject obj) { + PrivacyConstraint pc = obj.getPrivacyConstraint(); + return pc != null && pc.hasPrivateElements(); + } + + private boolean getHe() { + if(!getParameterMap().containsKey(PS_HE)) + return DEFAULT_HE; + return Boolean.parseBoolean(getParam(PS_HE)); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java new file mode 100644 index 00000000000..6fe2b3814f4 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This class abstracts over an encrypted matrix of ciphertexts. It stores the data as opaque byte array. The layout is unspecified. + */ +public class PlaintextMatrix extends Encrypted { + private static final long serialVersionUID = 5732436872261940616L; + + public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) { + super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN); + } + + @Override + public String getDebugName() { + return "PlaintextMatrix " + getData().hashCode(); + } +} diff --git a/src/main/java/org/apache/sysds/utils/NativeHelper.java b/src/main/java/org/apache/sysds/utils/NativeHelper.java index 83869c23d2f..e86bd56b19d 100644 --- a/src/main/java/org/apache/sysds/utils/NativeHelper.java +++ b/src/main/java/org/apache/sysds/utils/NativeHelper.java @@ -44,14 +44,14 @@ * By default, it first tries to load Intel MKL, else tries to load OpenBLAS. */ public class NativeHelper { - + public enum NativeBlasState { NOT_ATTEMPTED_LOADING_NATIVE_BLAS, SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE, SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE, ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY } - + public static NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS; private static String blasType; @@ -63,16 +63,16 @@ public enum NativeBlasState { /** * Called by Statistics to print the loaded BLAS. - * + * * @return empty string or the BLAS that is loaded */ public static String getCurrentBLAS() { return blasType != null ? blasType : ""; } - + /** * Called by runtime to check if the BLAS is available for exploitation - * + * * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE else false */ public static boolean isNativeLibraryLoaded() { @@ -99,10 +99,10 @@ public static boolean isNativeLibraryLoaded() { } return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE; } - + /** - * Initialize the native library before executing the DML program - * + * Initialize the native library before executing the DML program + * * @param customLibPath specified by sysds.native.blas.directory * @param userSpecifiedBLAS specified by sysds.native.blas */ @@ -121,22 +121,22 @@ else if(!isBLASLoaded() && isSupportedBLAS(userSpecifiedBLAS)) { performLoading(customLibPath, userSpecifiedBLAS); } } - + /** * Return true if the given BLAS type is supported. - * + * * @param userSpecifiedBLAS BLAS type specified via sysds.native.blas property * @return true if the userSpecifiedBLAS is auto | mkl | openblas, else false */ private static boolean isSupportedBLAS(String userSpecifiedBLAS) { - return userSpecifiedBLAS.equalsIgnoreCase("auto") || - userSpecifiedBLAS.equalsIgnoreCase("mkl") || + return userSpecifiedBLAS.equalsIgnoreCase("auto") || + userSpecifiedBLAS.equalsIgnoreCase("mkl") || userSpecifiedBLAS.equalsIgnoreCase("openblas"); } - + /** * Note: we only support 64 bit Java on x86 and AMD machine - * + * * @return true if the hardware architecture is supported */ private static boolean isSupportedArchitecture() { @@ -166,21 +166,21 @@ private static boolean isSupportedOS() { * SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE */ private static boolean isBLASLoaded() { - return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || + return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE; } - + /** * Check if we should attempt to perform loading. * If custom library path is provided, we should attempt to load again if not already loaded. - * - * @param customLibPath custom library path + * + * @param customLibPath custom library path * @return true if we should attempt to load blas again */ private static boolean shouldReload(String customLibPath) { boolean isValidBLASDirectory = customLibPath != null && !customLibPath.equalsIgnoreCase("none"); return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS || - (isValidBLASDirectory && !isBLASLoaded()); + (isValidBLASDirectory && !isBLASLoaded()); } // Performing loading in a method instead of a static block will throw a detailed stack trace in case of fatal errors @@ -191,13 +191,13 @@ private static void performLoading(String customLibPath, String userSpecifiedBLA // attemptedLoading variable ensures that we don't try to load SystemDS and other dependencies // again and again especially in the parfor (hence the double-checking with synchronized). if(shouldReload(customLibPath) && isSupportedBLAS(userSpecifiedBLAS) && isSupportedArchitecture() - && isSupportedOS()) { + && isSupportedOS()) { long start = System.nanoTime(); synchronized(NativeHelper.class) { if(shouldReload(customLibPath)) { // Set attempted loading unsuccessful in case of exception CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY; - String [] blas = new String[] { userSpecifiedBLAS }; + String [] blas = new String[] { userSpecifiedBLAS }; if(userSpecifiedBLAS.equalsIgnoreCase("auto")) { blas = new String[] { "mkl", "openblas" }; } @@ -206,7 +206,7 @@ && isSupportedOS()) { String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so"); String library_name = "libsystemds_" + blasType + platform_suffix; if(loadLibraryHelperFromResource(library_name) || - loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath.")) + loadBLAS(customLibPath, library_name,"Loading native helper with customLibPath.")) { LOG.info("Using native blas: " + blasType + getNativeBLASPath()); CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE; @@ -215,15 +215,15 @@ && isSupportedOS()) { } } double timeToLoadInMilliseconds = (System.nanoTime()-start)*1e-6; - if(timeToLoadInMilliseconds > 1000) + if(timeToLoadInMilliseconds > 1000) LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds."); } else if(LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) { LOG.debug("Using internal Java BLAS as native BLAS support instead of the configuration " + - "'sysds.native.blas'=" + userSpecifiedBLAS + "."); + "'sysds.native.blas'=" + userSpecifiedBLAS + "."); } } - + private static boolean checkAndLoadBLAS(String customLibPath, String [] listBLAS) { if(customLibPath != null && customLibPath.equalsIgnoreCase("none")) customLibPath = null; @@ -250,10 +250,10 @@ else if (blas.equalsIgnoreCase("openblas")) { } return isLoaded; } - + /** * Useful method for debugging. - * + * * @return empty string (if !LOG.isDebugEnabled()) or the path from where openblas or mkl is loaded. */ private static String getNativeBLASPath() { @@ -287,8 +287,8 @@ public static int getMaxNumThreads() { /** * Attempts to load native BLAS - * - * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the + * + * @param customLibPath can be null (if we want to only want to use LD_LIBRARY_PATH), else the * @param blas can be gomp, openblas or mkl_rt * @param optionalMsg message for debugging * @return true if successfully loaded BLAS @@ -300,8 +300,8 @@ public static boolean loadBLAS(String customLibPath, String blas, String optiona try { // This fixes libPath if it already contained a prefix/suffix and mapLibraryName added another one. libPath = libPath.replace("liblibsystemds", "libsystemds") - .replace(".dll.dll", ".dll") - .replace(".so.so", ".so"); + .replace(".dll.dll", ".dll") + .replace(".so.so", ".so"); System.load(libPath); LOG.info("Loaded the library:" + libPath); return true; @@ -321,7 +321,7 @@ public static boolean loadBLAS(String customLibPath, String blas, String optiona catch (UnsatisfiedLinkError e) { LOG.debug("java.library.path: " + System.getProperty("java.library.path")); LOG.debug("Unable to load " + blas + (optionalMsg == null ? "" : (" (" + optionalMsg + ")")) + - " \n Message from exception was: " + e.getMessage()); + " \n Message from exception was: " + e.getMessage()); return false; } } @@ -355,13 +355,13 @@ public static boolean loadLibraryHelperFromResource(String libFileName) { } // TODO: Add pmm, wsloss, mmchain, etc. - + //double-precision matrix multiply dense-dense public static native long dmmdd(double [] m1, double [] m2, double [] ret, int m1rlen, int m1clen, int m2clen, - int numThreads); + int numThreads); //single-precision matrix multiply dense-dense public static native long smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, int m1rlen, int m1clen, int m2clen, - int numThreads); + int numThreads); //transpose-self matrix multiply public static native long tsmm(double[] m1, double[] ret, int m1rlen, int m1clen, boolean leftTrans, int numThreads); @@ -374,27 +374,27 @@ public static native long smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, // Returns -1 if failures or returns number of nonzeros // Called by DnnCPInstruction if both input and filter are dense - public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, - int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads); + public static native long conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, + int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads); public static native long dconv2dBiasAddDense(double [] input, double [] bias, double [] filter, double [] ret, int N, - int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, - int numThreads); + int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, + int numThreads); public static native long sconv2dBiasAddDense(FloatBuffer input, FloatBuffer bias, FloatBuffer filter, FloatBuffer ret, - int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, - int numThreads); + int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, + int numThreads); // Called by DnnCPInstruction if both input and filter are dense public static native long conv2dBackwardFilterDense(double [] input, double [] dout, double [] ret, int N, int C, - int H, int W, int K, int R, int S, int stride_h, int stride_w, - int pad_h, int pad_w, int P, int Q, int numThreads); + int H, int W, int K, int R, int S, int stride_h, int stride_w, + int pad_h, int pad_w, int P, int Q, int numThreads); // If both filter and dout are dense, then called by DnnCPInstruction // Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse. public static native long conv2dBackwardDataDense(double [] filter, double [] dout, double [] ret, int N, int C, - int H, int W, int K, int R, int S, int stride_h, int stride_w, - int pad_h, int pad_w, int P, int Q, int numThreads); + int H, int W, int K, int R, int S, int stride_h, int stride_w, + int pad_h, int pad_w, int P, int Q, int numThreads); // Currently only supported with numThreads = 1 and sparse input // Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse. @@ -415,4 +415,4 @@ public static native boolean conv2dSparse(int apos, int alen, int[] aix, double[ // different tradeoffs. In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be // fastest. We can revisit this decision later and hence I would not recommend removing this method. private static native void setMaxNumThreads(int numThreads); -} +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index 77b63a99213..aece9b655a6 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -674,6 +674,8 @@ public static String display(int maxHeavyHitters) if(DMLScript.FED_STATISTICS) { sb.append("\n"); sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT)); + sb.append("\n"); + sb.append(ParamServStatistics.displayFloStatistics()); } return sb.toString(); diff --git a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java index 0d97bfd0c63..8eb26a19637 100644 --- a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java +++ b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.LongAdder; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; public class ParamServStatistics { @@ -41,6 +42,14 @@ public class ParamServStatistics { private static final LongAdder fedWorkerComputingTime = new LongAdder(); private static final LongAdder fedGradientWeightingTime = new LongAdder(); private static final LongAdder fedCommunicationTime = new LongAdder(); + private static final LongAdder fedNetworkTime = new LongAdder(); // measures exactly how long it takes netty to send & receive data + // Homomorphic encryption specifics (time is in milli sec) + private static final LongAdder heEncryption = new LongAdder(); // SEALClient::encrypt + private static final LongAdder heAccumulation = new LongAdder(); // SEALServer::accumulateCiphertexts + private static final LongAdder hePartialDecryption = new LongAdder(); // SEALClient::partiallyDecrypt + private static final LongAdder heDecryption = new LongAdder(); // SEALServer::average + + private static final LongAdder fedAggregation = new LongAdder(); // SEALServer::average public static void incWorkerNumber() { numWorkers.increment(); @@ -110,6 +119,14 @@ public static void accFedWorkerComputing(long t) { fedWorkerComputingTime.add(t); } + public static void accFedNetworkTime(long t) { + fedNetworkTime.add(t); + } + + public static void accFedAggregation(long t) { + fedAggregation.add(t); + } + public static void accFedGradientWeightingTime(long t) { fedGradientWeightingTime.add(t); } @@ -118,6 +135,22 @@ public static void accFedCommunicationTime(long t) { fedCommunicationTime.add(t); } + public static void accHEEncryptionTime(long t) { + heEncryption.add(t); + } + + public static void accHEAccumulation(long t) { + heAccumulation.add(t); + } + + public static void accHEPartialDecryptionTime(long t) { + hePartialDecryption.add(t); + } + + public static void accHEDecryptionTime(long t) { + heDecryption.add(t); + } + public static void reset() { executionTime.reset(); numWorkers.reset(); @@ -133,6 +166,12 @@ public static void reset() { fedWorkerComputingTime.reset(); fedGradientWeightingTime.reset(); fedCommunicationTime.reset(); + fedNetworkTime.reset(); + heEncryption.reset(); + heAccumulation.reset(); + hePartialDecryption.reset(); + heDecryption.reset(); + fedAggregation.reset(); } public static String displayStatistics() { @@ -168,4 +207,16 @@ private static String displayFedPSStatistics() { sb.append(String.format("PS fed grad. weigh. time (cum):\t%.3f secs.\n", fedGradientWeightingTime.doubleValue() / 1000)); return sb.toString(); } + + public static String displayFloStatistics() { + StringBuilder sb = new StringBuilder(); + sb.append(String.format("PS fed network time (cum):\t%.3f secs.\n", fedNetworkTime.doubleValue() / 1000)); + sb.append(String.format("PS fed agg time:\t%.3f secs.\n", fedAggregation.doubleValue() / 1000)); + sb.append(String.format("Paramserv grad compute time:\t%.3f secs.\n", gradientComputeTime.doubleValue() / 1000)); + sb.append(String.format("HE PS encryption time:\t%.3f secs.\n", heEncryption.doubleValue() / 1000)); + sb.append(String.format("HE PS accumulation time:\t%.3f secs.\n", heAccumulation.doubleValue() / 1000)); + sb.append(String.format("HE PS partial decryption time:\t%.3f secs.\n", hePartialDecryption.doubleValue() / 1000)); + sb.append(String.format("HE PS decryption time:\t%.3f secs.\n", heDecryption.doubleValue() / 1000)); + return sb.toString(); + } } diff --git a/src/main/python/systemds/operator/algorithm/__init__.py b/src/main/python/systemds/operator/algorithm/__init__.py index feb5342ecca..5352868e805 100644 --- a/src/main/python/systemds/operator/algorithm/__init__.py +++ b/src/main/python/systemds/operator/algorithm/__init__.py @@ -51,7 +51,7 @@ from .builtin.deepWalk import deepWalk from .builtin.denialConstraints import denialConstraints from .builtin.discoverFD import discoverFD -from .builtin.dist import dist +from .builtin.dist import dist from .builtin.dmv import dmv from .builtin.ema import ema from .builtin.executePipeline import executePipeline @@ -60,7 +60,7 @@ from .builtin.fit_pipeline import fit_pipeline from .builtin.fixInvalidLengths import fixInvalidLengths from .builtin.fixInvalidLengthsApply import fixInvalidLengthsApply -from .builtin.frameSort import frameSort +from .builtin.frameSort import frameSort from .builtin.frequencyEncode import frequencyEncode from .builtin.frequencyEncodeApply import frequencyEncodeApply from .builtin.garch import garch @@ -97,8 +97,8 @@ from .builtin.intersect import intersect from .builtin.km import km from .builtin.kmeans import kmeans -from .builtin.kmeansPredict import kmeansPredict -from .builtin.knn import knn +from .builtin.kmeansPredict import kmeansPredict +from .builtin.knn import knn from .builtin.knnGraph import knnGraph from .builtin.knnbf import knnbf from .builtin.l2svm import l2svm @@ -109,12 +109,11 @@ from .builtin.lm import lm from .builtin.lmCG import lmCG from .builtin.lmDS import lmDS -from .builtin.lmPredict import lmPredict from .builtin.logSumExp import logSumExp -from .builtin.matrixProfile import matrixProfile +from .builtin.matrixProfile import matrixProfile from .builtin.mcc import mcc from .builtin.mdedup import mdedup -from .builtin.mice import mice +from .builtin.mice import mice from .builtin.miceApply import miceApply from .builtin.msvm import msvm from .builtin.msvmPredict import msvmPredict @@ -158,11 +157,11 @@ from .builtin.symmetricDifference import symmetricDifference from .builtin.tSNE import tSNE from .builtin.toOneHot import toOneHot -from .builtin.tomeklink import tomeklink +from .builtin.tomeklink import tomeklink from .builtin.topk_cleaning import topk_cleaning from .builtin.underSampling import underSampling from .builtin.union import union -from .builtin.unique import unique +from .builtin.unique import unique from .builtin.univar import univar from .builtin.vectorToCsv import vectorToCsv from .builtin.winsorize import winsorize @@ -261,7 +260,6 @@ 'lm', 'lmCG', 'lmDS', - 'lmPredict', 'logSumExp', 'matrixProfile', 'mcc', diff --git a/src/main/python/systemds/operator/algorithm/builtin/winsorizeApply.py b/src/main/python/systemds/operator/algorithm/builtin/winsorizeApply.py index e09feabf506..83cc2c1c8ba 100644 --- a/src/main/python/systemds/operator/algorithm/builtin/winsorizeApply.py +++ b/src/main/python/systemds/operator/algorithm/builtin/winsorizeApply.py @@ -32,7 +32,11 @@ def winsorizeApply(X: Matrix, qLower: Matrix, qUpper: Matrix): - + """ + :param qLower: lower quantile + :param qUpper: upper quantile + :return: 'OperationNode' containing + """ params_dict = {'X': X, 'qLower': qLower, 'qUpper': qUpper} return Matrix(X.sds_context, 'winsorizeApply', diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 6ebff8eacd0..c5f7d1a54bb 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -652,6 +652,12 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ */ protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges) + { + rowFederateLocallyAndWriteInputMatrixWithMTD(name, matrix, numFederatedWorkers, ports, ranges, null); + } + + protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, + double[][] matrix, int numFederatedWorkers, List<Integer> ports, double[][] ranges, PrivacyConstraint privacyConstraint) { // check matrix non empty if(matrix.length == 0 || matrix[0].length == 0) @@ -677,7 +683,7 @@ protected void rowFederateLocallyAndWriteInputMatrixWithMTD(String name, // write slice writeInputMatrixWithMTD(path, Arrays.copyOfRange(matrix, (int)lowerBound, (int)upperBound), false, new MatrixCharacteristics((long) examplesForWorkerI, ncol, - OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol)); + OptimizerUtils.DEFAULT_BLOCKSIZE, (long) examplesForWorkerI * ncol), privacyConstraint); // generate fedmap entry FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol}); @@ -688,7 +694,7 @@ false, new MatrixCharacteristics((long) examplesForWorkerI, ncol, federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap)); federatedMatrixObject.getFedMapping().setType(FType.ROW); - writeInputFederatedWithMTD(name, federatedMatrixObject, null); + writeInputFederatedWithMTD(name, federatedMatrixObject, privacyConstraint); } protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java new file mode 100644 index 00000000000..25bc5b4ae77 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.paramserv; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.codegen.SpoofCompiler; +import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class EncryptedFederatedParamservTest extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(EncryptedFederatedParamservTest.class.getName()); + private final static String TEST_DIR = "functions/federated/paramserv/"; + private final static String TEST_NAME = "EncryptedFederatedParamservTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + EncryptedFederatedParamservTest.class.getSimpleName() + "/"; + + private final String _networkType; + private final int _numFederatedWorkers; + private final int _dataSetSize; + private final int _epochs; + private final int _batch_size; + private final double _eta; + private final String _utype; + private final String _freq; + private final String _scheme; + private final String _runtime_balancing; + private final String _weighting; + private final String _data_distribution; + private final int _seed; + + // parameters + @Parameterized.Parameters + public static Collection<Object[]> parameters() { + return Arrays.asList(new Object[][] { + // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency + // basic functionality + //{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED", 200}, + + // One important point is that we do the model averaging in the case of BSP + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "BASELINE", "false", "IMBALANCED", 200}, + {"CNN", 2, 4, 1, 1, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "BASELINE", "false", "IMBALANCED", 200}, + //{"TwoNN", 5, 1000, 100, 1, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED", 200}, + + /* + // runtime balancing + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED", 200}, + + // data partitioning + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE", "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE", "true", "IMBALANCED", 200}, + + // balanced tests + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED", 200} + */ + }); + } + + public EncryptedFederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, + int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighting, String data_distribution, int seed) { + try { + NativeHEHelper.initialize(); + } catch (Exception e) { + throw e; + } + _networkType = networkType; + _numFederatedWorkers = numFederatedWorkers; + _dataSetSize = dataSetSize; + _batch_size = batch_size; + _epochs = epochs; + _eta = eta; + _utype = utype; + _freq = freq; + _scheme = scheme; + _runtime_balancing = runtime_balancing; + _weighting = weighting; + _data_distribution = data_distribution; + _seed = seed; + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void EncryptedfederatedParamservSingleNode() { + EncryptedfederatedParamserv(ExecMode.SINGLE_NODE, true); + } + + @Test + public void EncryptedfederatedParamservHybrid() { + EncryptedfederatedParamserv(ExecMode.HYBRID, true); + } + + private void EncryptedfederatedParamserv(ExecMode mode, boolean modelAvg) { + // Warning Statistics accumulate in unit test + // config + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + setOutputBuffering(true); + + int C = 1, Hin = 28, Win = 28; + int numLabels = 10; + + ExecMode platformOld = setExecMode(mode); + + try { + // start threads + List<Integer> ports = new ArrayList<>(); + List<Thread> threads = new ArrayList<>(); + for(int i = 0; i < _numFederatedWorkers; i++) { + ports.add(getRandomAvailablePort()); + threads.add(startLocalFedWorkerThread(ports.get(i), + i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S)); + } + + // generate test data + double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win); + double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels); + String featuresName = ""; + String labelsName = ""; + + PrivacyConstraint privacyConstraint = new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private); + + // federate test data balanced or imbalanced + if(_data_distribution.equals("IMBALANCED")) { + featuresName = "X_IMBALANCED_" + _numFederatedWorkers; + labelsName = "y_IMBALANCED_" + _numFederatedWorkers; + double[][] ranges = {{0,1}, {1,4}}; + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint); + } + else { + featuresName = "X_BALANCED_" + _numFederatedWorkers; + labelsName = "y_BALANCED_" + _numFederatedWorkers; + double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length); + rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges, privacyConstraint); + rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint); + } + + try { + //wait for all workers to be setup + Thread.sleep(FED_WORKER_WAIT); + } + catch(InterruptedException e) { + e.printStackTrace(); + } + + // dml name + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + // generate program args + List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats", + "-nvargs", + "features=" + input(featuresName), + "labels=" + input(labelsName), + "epochs=" + _epochs, + "batch_size=" + _batch_size, + "eta=" + _eta, + "utype=" + _utype, + "freq=" + _freq, + "scheme=" + _scheme, + "runtime_balancing=" + _runtime_balancing, + "weighting=" + _weighting, + "network_type=" + _networkType, + "channels=" + C, + "hin=" + Hin, + "win=" + Win, + "seed=" + _seed, + "modelAvg=" + Boolean.toString(modelAvg).toUpperCase())); + + programArgs = programArgsList.toArray(new String[0]); + String log = runTest(null).toString(); + Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst()); + + // shut down threads + for(int i = 0; i < _numFederatedWorkers; i++) { + TestUtils.shutdownThreads(threads.get(i)); + } + } + finally { + resetExecMode(platformOld); + } + } + + /** + * Generates an feature matrix that has the same format as the MNIST dataset, + * but is completely random and normalized + * + * @param numExamples Number of examples to generate + * @param C Channels in the input data + * @param Hin Height in Pixels of the input data + * @param Win Width in Pixels of the input data + * @return a dummy MNIST feature matrix + */ + private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1); + } + + /** + * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists + * of one hot encoded vectors as rows + * + * @param numExamples Number of examples to generate + * @param numLabels Number of labels to generate + * @return a dummy MNIST lable matrix + */ + private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java new file mode 100644 index 00000000000..58600887c71 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/homomorphicEncryption/InOutTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.homomorphicEncryption; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALClient; +import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer; +import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix; +import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class InOutTest extends AutomatedTestBase { + private final static String TEST_NAME = "InOutTest"; + private final static String TEST_DIR = "functions/data/"; + private final static String TEST_CLASS_DIR = TEST_DIR + InOutTest.class.getSimpleName() + "/"; + + private final int num_clients = 3; + + private final int rows = 100; + private final int cols = 200; + private final long seed = 42; + + @Override + public void setUp() { + try { + NativeHEHelper.initialize(); + } catch (Exception e) { + throw e; + } + + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "C" }) ); + } + + @Test + public void endToEndTest() { + SEALServer server = new SEALServer(); + + SEALClient[] clients = new SEALClient[num_clients]; + PublicKey[] partial_pub_keys = new PublicKey[num_clients]; + for (int i = 0; i < num_clients; i++) { + clients[i] = new SEALClient(server.generateA()); + partial_pub_keys[i] = clients[i].generatePartialPublicKey(); + } + + PublicKey public_key = server.aggregatePartialPublicKeys(partial_pub_keys); + + MatrixObject[] plaintexts = new MatrixObject[num_clients]; + CiphertextMatrix[] ciphertexts = new CiphertextMatrix[num_clients]; + for (int i = 0; i < num_clients; i++) { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, -100, 100, 1.0, seed+i); + MatrixObject mo = new MatrixObject(Types.ValueType.FP64, null); + mo.setMetaData(new MetaDataFormat(new MatrixCharacteristics(rows, cols), Types.FileFormat.BINARY)); + mo.acquireModify(mb); + mo.release(); + plaintexts[i] = mo; + + clients[i].setPublicKey(public_key); + ciphertexts[i] = clients[i].encrypt(plaintexts[i]); + } + + CiphertextMatrix encrypted_sum = server.accumulateCiphertexts(ciphertexts); + + PlaintextMatrix[] partial_decryptions = new PlaintextMatrix[num_clients]; + for (int i = 0; i < num_clients; i++) { + partial_decryptions[i] = clients[i].partiallyDecrypt(encrypted_sum); + } + + MatrixObject result = server.average(encrypted_sum, partial_decryptions); + + double[] expected_raw_result = new double[rows*cols]; + double[][] plaintexts_raw = new double[num_clients][]; + for (int i = 0; i < num_clients; i++) { + plaintexts_raw[i] = plaintexts[i].acquireReadAndRelease().getDenseBlockValues(); + } + for (int x = 0; x < rows * cols; x++) { + double sum = 0.0; + for (int i = 0; i < num_clients; i++) { + sum += plaintexts_raw[i][x]; + } + expected_raw_result[x] = sum / num_clients; + } + + double[] raw_result = result.acquireReadAndRelease().getDenseBlockValues(); + assert result.getNumRows() == rows; + assert result.getNumColumns() == cols; + assert raw_result.length == rows*cols; + TestUtils.compareMatrices(raw_result, expected_raw_result, 5e-8); + } +} diff --git a/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml new file mode 100644 index 00000000000..b8021867dc9 --- /dev/null +++ b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml @@ -0,0 +1,61 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN +source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as TwoNNModelAvg +source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN +source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as CNNModelAvg + + +# create federated input matrices +features = read($features) +labels = read($labels) + +if($network_type == "TwoNN") { + if(!as.logical($modelAvg)) { + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed) + print("Test results:") + [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list()) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } + else if (as.logical($modelAvg)){ + model = TwoNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $modelAvg) + print("Test results:") + [loss_test, accuracy_test] = TwoNNModelAvg::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list()) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } +} +else if($network_type == "CNN") { + if(!as.logical($modelAvg)) { + model = CNN::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed) + print("Test results:") + hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win) + [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } + else if (as.logical($modelAvg)){ + model = CNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed, $modelAvg) + print("Test results:") + hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win) + [loss_test, accuracy_test] = CNNModelAvg::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams) + print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n") + } +}