Skip to content

Commit

Permalink
optimize keyswitching (#137)
Browse files Browse the repository at this point in the history
* optimize keyswitching

* added comments
  • Loading branch information
GelilaSeifu authored Dec 5, 2022
1 parent 311fe72 commit 2d196fd
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 14 deletions.
25 changes: 11 additions & 14 deletions hexl/experimental/seal/key-switch-internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#include "hexl/eltwise/eltwise-fma-mod.hpp"
#include "hexl/eltwise/eltwise-mult-mod.hpp"
#include "hexl/eltwise/eltwise-reduce-mod.hpp"
#include "hexl/experimental/seal/ntt-cache.hpp"
#include "hexl/logging/logging.hpp"
#include "hexl/ntt/ntt.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/aligned-allocator.hpp"
#include "hexl/util/check.hpp"
Expand All @@ -36,22 +36,20 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
uint64_t coeff_count = n;

// Create a copy of target_iter
std::vector<uint64_t> t_target(coeff_count * decomp_modulus_size, 0);
for (size_t i = 0; i < coeff_count * decomp_modulus_size; ++i) {
t_target[i] = t_target_iter_ptr[i];
}

uint64_t* t_target_ptr = &t_target[0];
std::vector<uint64_t> t_target(
t_target_iter_ptr,
t_target_iter_ptr + (coeff_count * decomp_modulus_size));
uint64_t* t_target_ptr = t_target.data();

// Simplified implementation, where we assume no modular reduction is required
// for intermediate additions
std::vector<uint64_t> t_ntt(coeff_count, 0);
uint64_t* t_ntt_ptr = &t_ntt[0];
uint64_t* t_ntt_ptr = t_ntt.data();

// In CKKS t_target is in NTT form; switch
// back to normal form
for (size_t j = 0; j < decomp_modulus_size; ++j) {
NTT(n, moduli[j])
GetNTT(n, moduli[j])
.ComputeInverse(&t_target_ptr[j * coeff_count],
&t_target_ptr[j * coeff_count], 2, 1);
}
Expand Down Expand Up @@ -87,8 +85,7 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
}

// NTT conversion lazy outputs in [0, 4q)
NTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4);

GetNTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4);
t_operand = t_ntt_ptr;
}

Expand Down Expand Up @@ -141,7 +138,8 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
&t_poly_prod[key_component * coeff_count * rns_modulus_size];
uint64_t* t_last = &t_poly_prod_it[decomp_modulus_size * coeff_count];

NTT(n, moduli[key_modulus_size - 1]).ComputeInverse(t_last, t_last, 2, 2);
GetNTT(n, moduli[key_modulus_size - 1])
.ComputeInverse(t_last, t_last, 2, 2);

uint64_t qk = moduli[key_modulus_size - 1];
uint64_t qk_half = qk >> 1;
Expand Down Expand Up @@ -178,8 +176,7 @@ void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
}

uint64_t qi_lazy = qi << 1; // some multiples of qi

NTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4);
GetNTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4);
// Since SEAL uses at most 60bit moduli, 8*qi < 2^63.
qi_lazy = qi << 2;

Expand Down
35 changes: 35 additions & 0 deletions hexl/include/hexl/experimental/seal/locks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <mutex>
#include <shared_mutex>

namespace intel {
namespace hexl {

using Lock = std::shared_mutex;
using WriteLock = std::unique_lock<Lock>;
using ReadLock = std::shared_lock<Lock>;

class RWLock {
public:
RWLock() = default;
inline ReadLock AcquireRead() { return ReadLock(rw_mutex); }
inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); }
inline ReadLock TryAcquireRead() noexcept {
return ReadLock(rw_mutex, std::try_to_lock);
}
inline WriteLock TryAcquireWrite() noexcept {
return WriteLock(rw_mutex, std::try_to_lock);
}

private:
RWLock(const RWLock& copy) = delete;
RWLock& operator=(const RWLock& assign) = delete;
Lock rw_mutex{};
};

} // namespace hexl
} // namespace intel
56 changes: 56 additions & 0 deletions hexl/include/hexl/experimental/seal/ntt-cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "hexl/experimental/seal/locks.hpp"
#include "ntt/ntt-internal.hpp"

namespace intel {
namespace hexl {

struct HashPair {
template <class T1, class T2>
std::size_t operator()(const std::pair<T1, T2>& p) const {
auto hash1 = std::hash<T1>{}(p.first);
auto hash2 = std::hash<T2>{}(p.second);
return hash_combine(hash1, hash2);
}

// Golden Ratio Hashing with seeds
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};

NTT& GetNTT(size_t N, uint64_t modulus) {
static std::unordered_map<std::pair<uint64_t, uint64_t>, NTT, HashPair>
ntt_cache;
static RWLock ntt_cache_locker;

std::pair<uint64_t, uint64_t> key{N, modulus};

// Enable shared access to NTT already present
{
ReadLock reader_lock(ntt_cache_locker.AcquireRead());
auto ntt_it = ntt_cache.find(key);
if (ntt_it != ntt_cache.end()) {
return ntt_it->second;
}
}

// Deal with NTT not yet present
WriteLock write_lock(ntt_cache_locker.AcquireWrite());

// Check ntt_cache for value (may be added by another thread)
auto ntt_it = ntt_cache.find(key);
if (ntt_it == ntt_cache.end()) {
NTT ntt(N, modulus);
ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first;
}
return ntt_it->second;
}

} // namespace hexl
} // namespace intel

0 comments on commit 2d196fd

Please sign in to comment.