From 2d196fdd71f24511bd7e0e23dc07d37c888f53e7 Mon Sep 17 00:00:00 2001 From: Gelila Seifu Date: Mon, 5 Dec 2022 11:39:01 -0600 Subject: [PATCH] optimize keyswitching (#137) * optimize keyswitching * added comments --- .../experimental/seal/key-switch-internal.cpp | 25 ++++----- hexl/include/hexl/experimental/seal/locks.hpp | 35 ++++++++++++ .../hexl/experimental/seal/ntt-cache.hpp | 56 +++++++++++++++++++ 3 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 hexl/include/hexl/experimental/seal/locks.hpp create mode 100644 hexl/include/hexl/experimental/seal/ntt-cache.hpp diff --git a/hexl/experimental/seal/key-switch-internal.cpp b/hexl/experimental/seal/key-switch-internal.cpp index 04477979..15edb9a8 100644 --- a/hexl/experimental/seal/key-switch-internal.cpp +++ b/hexl/experimental/seal/key-switch-internal.cpp @@ -11,8 +11,8 @@ #include "hexl/eltwise/eltwise-fma-mod.hpp" #include "hexl/eltwise/eltwise-mult-mod.hpp" #include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/experimental/seal/ntt-cache.hpp" #include "hexl/logging/logging.hpp" -#include "hexl/ntt/ntt.hpp" #include "hexl/number-theory/number-theory.hpp" #include "hexl/util/aligned-allocator.hpp" #include "hexl/util/check.hpp" @@ -36,22 +36,20 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, uint64_t coeff_count = n; // Create a copy of target_iter - std::vector t_target(coeff_count * decomp_modulus_size, 0); - for (size_t i = 0; i < coeff_count * decomp_modulus_size; ++i) { - t_target[i] = t_target_iter_ptr[i]; - } - - uint64_t* t_target_ptr = &t_target[0]; + std::vector t_target( + t_target_iter_ptr, + t_target_iter_ptr + (coeff_count * decomp_modulus_size)); + uint64_t* t_target_ptr = t_target.data(); // Simplified implementation, where we assume no modular reduction is required // for intermediate additions std::vector t_ntt(coeff_count, 0); - uint64_t* t_ntt_ptr = &t_ntt[0]; + uint64_t* t_ntt_ptr = t_ntt.data(); // In CKKS t_target is in NTT form; switch // back to normal form for (size_t j = 0; j < decomp_modulus_size; ++j) { - NTT(n, moduli[j]) + GetNTT(n, moduli[j]) .ComputeInverse(&t_target_ptr[j * coeff_count], &t_target_ptr[j * coeff_count], 2, 1); } @@ -87,8 +85,7 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, } // NTT conversion lazy outputs in [0, 4q) - NTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); - + GetNTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); t_operand = t_ntt_ptr; } @@ -141,7 +138,8 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, &t_poly_prod[key_component * coeff_count * rns_modulus_size]; uint64_t* t_last = &t_poly_prod_it[decomp_modulus_size * coeff_count]; - NTT(n, moduli[key_modulus_size - 1]).ComputeInverse(t_last, t_last, 2, 2); + GetNTT(n, moduli[key_modulus_size - 1]) + .ComputeInverse(t_last, t_last, 2, 2); uint64_t qk = moduli[key_modulus_size - 1]; uint64_t qk_half = qk >> 1; @@ -178,8 +176,7 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, } uint64_t qi_lazy = qi << 1; // some multiples of qi - - NTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); + GetNTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); // Since SEAL uses at most 60bit moduli, 8*qi < 2^63. qi_lazy = qi << 2; diff --git a/hexl/include/hexl/experimental/seal/locks.hpp b/hexl/include/hexl/experimental/seal/locks.hpp new file mode 100644 index 00000000..4595f4e5 --- /dev/null +++ b/hexl/include/hexl/experimental/seal/locks.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace intel { +namespace hexl { + +using Lock = std::shared_mutex; +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +class RWLock { + public: + RWLock() = default; + inline ReadLock AcquireRead() { return ReadLock(rw_mutex); } + inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); } + inline ReadLock TryAcquireRead() noexcept { + return ReadLock(rw_mutex, std::try_to_lock); + } + inline WriteLock TryAcquireWrite() noexcept { + return WriteLock(rw_mutex, std::try_to_lock); + } + + private: + RWLock(const RWLock& copy) = delete; + RWLock& operator=(const RWLock& assign) = delete; + Lock rw_mutex{}; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl/include/hexl/experimental/seal/ntt-cache.hpp b/hexl/include/hexl/experimental/seal/ntt-cache.hpp new file mode 100644 index 00000000..8f6c1046 --- /dev/null +++ b/hexl/include/hexl/experimental/seal/ntt-cache.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/seal/locks.hpp" +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +struct HashPair { + template + std::size_t operator()(const std::pair& p) const { + auto hash1 = std::hash{}(p.first); + auto hash2 = std::hash{}(p.second); + return hash_combine(hash1, hash2); + } + + // Golden Ratio Hashing with seeds + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; + +NTT& GetNTT(size_t N, uint64_t modulus) { + static std::unordered_map, NTT, HashPair> + ntt_cache; + static RWLock ntt_cache_locker; + + std::pair key{N, modulus}; + + // Enable shared access to NTT already present + { + ReadLock reader_lock(ntt_cache_locker.AcquireRead()); + auto ntt_it = ntt_cache.find(key); + if (ntt_it != ntt_cache.end()) { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + WriteLock write_lock(ntt_cache_locker.AcquireWrite()); + + // Check ntt_cache for value (may be added by another thread) + auto ntt_it = ntt_cache.find(key); + if (ntt_it == ntt_cache.end()) { + NTT ntt(N, modulus); + ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; +} + +} // namespace hexl +} // namespace intel