Skip to content

Commit

Permalink
Cut sparse state vector (#1018)
Browse files Browse the repository at this point in the history
* Cut sparse state vector (use QBDD)

* Per #1017 discussion
  • Loading branch information
WrathfulSpatula authored Oct 21, 2024
1 parent 0f73a9a commit 9fa593a
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 537 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required (VERSION 3.9)
project (Qrack VERSION 9.11.11 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)
project (Qrack VERSION 9.11.12 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)

# Installation commands
include (GNUInstallDirs)
Expand Down
8 changes: 3 additions & 5 deletions include/qengine_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ class QEngineCPU : public QEngine {
DispatchQueue dispatchQueue;
#endif

StateVectorSparsePtr CastStateVecSparse() { return std::dynamic_pointer_cast<StateVectorSparse>(stateVec); }

public:
QEngineCPU(bitLenInt qBitCount, const bitCapInt& initState, qrack_rand_gen_ptr rgp = nullptr,
const complex& phaseFac = CMPLX_DEFAULT_ARG, bool doNorm = false, bool randomGlobalPhase = true,
bool ignored = false, int64_t ignored2 = -1, bool useHardwareRNG = true, bool useSparseStateVec = false,
real1_f norm_thresh = REAL1_EPSILON, std::vector<int64_t> ignored3 = {}, bitLenInt ignored4 = 0U,
real1_f ignored5 = _qrack_qunit_sep_thresh);
bool ignored = false, int64_t ignored2 = -1, bool useHardwareRNG = true, bool ignored3 = false,
real1_f norm_thresh = REAL1_EPSILON, std::vector<int64_t> ignored4 = {}, bitLenInt ignored5 = 0U,
real1_f ignored6 = _qrack_qunit_sep_thresh);

~QEngineCPU() { Dump(); }

Expand Down
332 changes: 0 additions & 332 deletions include/statevector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
#include <future>
#endif

#include <unordered_map>
#define SparseStateVecMap std::unordered_map<bitCapIntOcl, complex>

#if ENABLE_COMPLEX_X2
#if FPPOW == 5
#include "common/complex8x2simd.hpp"
Expand All @@ -36,7 +33,6 @@
namespace Qrack {

class StateVectorArray;
class StateVectorSparse;

// This is a buffer struct that's capable of representing controlled single bit gates and arithmetic, when subclassed.
class StateVector : public ParallelFor {
Expand Down Expand Up @@ -74,7 +70,6 @@ class StateVector : public ParallelFor {
virtual void copy(StateVectorPtr toCopy) = 0;
virtual void shuffle(StateVectorPtr svp) = 0;
virtual void get_probs(real1* outArray) = 0;
virtual bool is_sparse() = 0;
};

class StateVectorArray : public StateVector {
Expand Down Expand Up @@ -217,332 +212,5 @@ class StateVectorArray : public StateVector {
par_for(
0, capacity, [&](const bitCapIntOcl& lcv, const unsigned& cpu) { outArray[lcv] = norm(amplitudes[lcv]); });
}

bool is_sparse() { return false; }
};

class StateVectorSparse : public StateVector {
protected:
SparseStateVecMap amplitudes;
std::mutex mtx;

complex readUnlocked(const bitCapIntOcl& i)
{
auto it = amplitudes.find(i);
return (it == amplitudes.end()) ? ZERO_CMPLX : it->second;
}

complex readLocked(const bitCapIntOcl& i)
{
std::lock_guard<std::mutex> lock(mtx);
return readUnlocked(i);
}

public:
StateVectorSparse(bitCapIntOcl cap)
: StateVector(cap)
, amplitudes()
{
}

complex read(const bitCapIntOcl& i) { return isReadLocked ? readLocked(i) : readUnlocked(i); }

#if ENABLE_COMPLEX_X2
complex2 read2(const bitCapIntOcl& i1, const bitCapIntOcl& i2)
{
if (isReadLocked) {
return complex2(readLocked(i1), readLocked(i2));
}
return complex2(readUnlocked(i1), readUnlocked(i2));
}
#endif

void write(const bitCapIntOcl& i, const complex& c)
{
const bool isCSet = abs(c) > REAL1_EPSILON;
if (isCSet) {
std::lock_guard<std::mutex> lock(mtx);
amplitudes[i] = c;
} else {
std::lock_guard<std::mutex> lock(mtx);
amplitudes.erase(i);
}
}

void write2(const bitCapIntOcl& i1, const complex& c1, const bitCapIntOcl& i2, const complex& c2)
{
const bool isC1Set = abs(c1) > REAL1_EPSILON;
const bool isC2Set = abs(c2) > REAL1_EPSILON;
if (!isC1Set && !isC2Set) {
std::lock_guard<std::mutex> lock(mtx);
amplitudes.erase(i1);
amplitudes.erase(i2);
} else if (isC1Set && isC2Set) {
std::lock_guard<std::mutex> lock(mtx);
amplitudes[i1] = c1;
amplitudes[i2] = c2;
} else if (isC1Set) {
std::lock_guard<std::mutex> lock(mtx);
amplitudes.erase(i2);
amplitudes[i1] = c1;
} else {
std::lock_guard<std::mutex> lock(mtx);
amplitudes.erase(i1);
amplitudes[i2] = c2;
}
}

void clear()
{
std::lock_guard<std::mutex> lock(mtx);
amplitudes.clear();
}

void copy_in(const complex* copyIn)
{
if (!copyIn) {
clear();
return;
}

std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < capacity; ++i) {
if (abs(copyIn[i]) <= REAL1_EPSILON) {
amplitudes.erase(i);
} else {
amplitudes[i] = copyIn[i];
}
}
}

void copy_in(const complex* copyIn, const bitCapIntOcl offset, const bitCapIntOcl length)
{
if (!copyIn) {
std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < length; ++i) {
amplitudes.erase(i);
}

return;
}

std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < length; ++i) {
if (abs(copyIn[i]) <= REAL1_EPSILON) {
amplitudes.erase(i);
} else {
amplitudes[i + offset] = copyIn[i];
}
}
}

void copy_in(
StateVectorPtr copyInSv, const bitCapIntOcl srcOffset, const bitCapIntOcl dstOffset, const bitCapIntOcl length)
{
StateVectorSparsePtr copyIn = std::dynamic_pointer_cast<StateVectorSparse>(copyInSv);

if (!copyIn) {
std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < length; ++i) {
amplitudes.erase(i + srcOffset);
}

return;
}

std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < length; ++i) {
complex amp = copyIn->read(i + srcOffset);
if (abs(amp) <= REAL1_EPSILON) {
amplitudes.erase(i + srcOffset);
} else {
amplitudes[i + dstOffset] = amp;
}
}
}

void copy_out(complex* copyOut)
{
for (bitCapIntOcl i = 0U; i < capacity; ++i) {
copyOut[i] = read(i);
}
}

void copy_out(complex* copyOut, const bitCapIntOcl offset, const bitCapIntOcl length)
{
for (bitCapIntOcl i = 0U; i < length; ++i) {
copyOut[i] = read(i + offset);
}
}

void copy(const StateVectorPtr toCopy) { copy(std::dynamic_pointer_cast<StateVectorSparse>(toCopy)); }

void copy(StateVectorSparsePtr toCopy)
{
std::lock_guard<std::mutex> lock(mtx);
amplitudes = toCopy->amplitudes;
}

void shuffle(StateVectorPtr svp) { shuffle(std::dynamic_pointer_cast<StateVectorSparse>(svp)); }

void shuffle(StateVectorSparsePtr svp)
{
const size_t halfCap = (size_t)(capacity >> 1U);
std::lock_guard<std::mutex> lock(mtx);
for (bitCapIntOcl i = 0U; i < halfCap; ++i) {
complex amp = svp->read(i);
svp->write(i, read(i + halfCap));
write(i + halfCap, amp);
}
}

void get_probs(real1* outArray)
{
for (bitCapIntOcl i = 0U; i < capacity; ++i) {
outArray[i] = norm(read(i));
}
}

bool is_sparse() { return (amplitudes.size() < (size_t)(capacity >> 1U)); }

std::vector<bitCapIntOcl> iterable()
{
std::vector<std::vector<bitCapIntOcl>> toRet(GetConcurrencyLevel());
std::vector<std::vector<bitCapIntOcl>>::iterator toRetIt;

// For lock_guard scope
if (true) {
std::lock_guard<std::mutex> lock(mtx);

par_for(0U, amplitudes.size(), [&](const bitCapIntOcl& lcv, const unsigned& cpu) {
auto it = amplitudes.begin();
std::advance(it, lcv);
toRet[cpu].push_back(it->first);
});
}

for (int64_t i = (int64_t)(toRet.size() - 1U); i >= 0; i--) {
if (toRet[i].empty()) {
toRetIt = toRet.begin();
std::advance(toRetIt, i);
toRet.erase(toRetIt);
}
}

if (toRet.empty()) {
return {};
}

while (toRet.size() > 1U) {
// Work odd unit into collapse sequence:
if (toRet.size() & 1U) {
toRet[toRet.size() - 2U].insert(
toRet[toRet.size() - 2U].end(), toRet[toRet.size() - 1U].begin(), toRet[toRet.size() - 1U].end());
toRet.pop_back();
}

const int64_t combineCount = (int64_t)(toRet.size() >> 1U);
#if ENABLE_PTHREAD
std::vector<std::future<void>> futures(combineCount);
for (int64_t i = (combineCount - 1); i >= 0; i--) {
futures[i] = std::async(std::launch::async, [i, combineCount, &toRet]() {
toRet[i].insert(toRet[i].end(), toRet[i + combineCount].begin(), toRet[i + combineCount].end());
toRet[i + combineCount].clear();
});
}
for (int64_t i = (combineCount - 1); i >= 0; i--) {
futures[i].get();
toRet.pop_back();
}
#else
for (int64_t i = (combineCount - 1); i >= 0; i--) {
toRet[i].insert(toRet[i].end(), toRet[i + combineCount].begin(), toRet[i + combineCount].end());
toRet.pop_back();
}
#endif
}

return toRet[0U];
}

/// Returns empty if iteration should be over full set, otherwise just the iterable elements:
std::set<bitCapIntOcl> iterable(
const bitCapIntOcl& setMask, const bitCapIntOcl& filterMask = 0, const bitCapIntOcl& filterValues = 0)
{
if (!filterMask && filterValues) {
return {};
}

const bitCapIntOcl unsetMask = ~setMask;

std::vector<std::set<bitCapIntOcl>> toRet(GetConcurrencyLevel());
std::vector<std::set<bitCapIntOcl>>::iterator toRetIt;

// For lock_guard scope
if (true) {
std::lock_guard<std::mutex> lock(mtx);

if (!filterMask && !filterValues) {
par_for(0U, amplitudes.size(), [&](const bitCapIntOcl& lcv, const unsigned& cpu) {
auto it = amplitudes.begin();
std::advance(it, lcv);
toRet[cpu].insert(it->first & unsetMask);
});
} else {
const bitCapIntOcl unfilterMask = ~filterMask;
par_for(0U, amplitudes.size(), [&](const bitCapIntOcl lcv, const unsigned& cpu) {
auto it = amplitudes.begin();
std::advance(it, lcv);
if ((it->first & filterMask) == filterValues) {
toRet[cpu].insert(it->first & unsetMask & unfilterMask);
}
});
}
}

for (int64_t i = (int64_t)(toRet.size() - 1U); i >= 0; i--) {
if (toRet[i].empty()) {
toRetIt = toRet.begin();
std::advance(toRetIt, i);
toRet.erase(toRetIt);
}
}

if (toRet.empty()) {
return {};
}

while (toRet.size() > 1U) {
// Work odd unit into collapse sequence:
if (toRet.size() & 1U) {
toRet[toRet.size() - 2U].insert(toRet[toRet.size() - 1U].begin(), toRet[toRet.size() - 1U].end());
toRet.pop_back();
}

const int64_t combineCount = (int64_t)(toRet.size() >> 1U);
#if ENABLE_PTHREAD
std::vector<std::future<void>> futures(combineCount);
for (int64_t i = (combineCount - 1); i >= 0; i--) {
futures[i] = std::async(std::launch::async, [i, combineCount, &toRet]() {
toRet[i].insert(toRet[i + combineCount].begin(), toRet[i + combineCount].end());
toRet[i + combineCount].clear();
});
}

for (int64_t i = (combineCount - 1); i >= 0; i--) {
futures[i].get();
toRet.pop_back();
}
#else
for (int64_t i = (combineCount - 1); i >= 0; i--) {
toRet[i].insert(toRet[i + combineCount].begin(), toRet[i + combineCount].end());
toRet.pop_back();
}
#endif
}

return toRet[0U];
}
};

} // namespace Qrack
Loading

0 comments on commit 9fa593a

Please sign in to comment.