Skip to content

Commit

Permalink
Fix FP128 builds
Browse files Browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Jun 15, 2023
1 parent 29a3d81 commit a96a133
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 21 deletions.
8 changes: 4 additions & 4 deletions include/qneuron.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ class QNeuron {
std::unique_ptr<real1[]> angles;
QInterfacePtr qReg;

static real1_f applyRelu(real1_f angle) { return std::max(ZERO_R1_F, angle); }
static real1_f applyRelu(real1_f angle) { return std::max((real1_f)ZERO_R1_F, (real1_f)angle); }

static real1_f negApplyRelu(real1_f angle) { return -std::max(ZERO_R1_F, angle); }
static real1_f negApplyRelu(real1_f angle) { return -std::max((real1_f)ZERO_R1_F, (real1_f)angle); }

static real1_f applyGelu(real1_f angle) { return angle * (1 + erf(angle * SQRT1_2_R1)); }
static real1_f applyGelu(real1_f angle) { return angle * (1 + erf((real1_s)(angle * SQRT1_2_R1))); }

static real1_f negApplyGelu(real1_f angle) { return -angle * (1 + erf(angle * SQRT1_2_R1)); }
static real1_f negApplyGelu(real1_f angle) { return -angle * (1 + erf((real1_s)(angle * SQRT1_2_R1))); }

static real1_f applyAlpha(real1_f angle, real1_f alpha)
{
Expand Down
4 changes: 2 additions & 2 deletions include/qunit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,10 @@ class QUnit : public QParity, public QInterface {
}

if (IS_NORM_0(shard.amp1)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp1)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp1)));
SeparateBit(false, qubit);
} else if (IS_NORM_0(shard.amp0)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp0)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp0)));
SeparateBit(true, qubit);
}
}
Expand Down
20 changes: 18 additions & 2 deletions src/pinvoke_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2897,20 +2897,32 @@ MICROSOFT_QUANTUM_DECL void destroy_qneuron(_In_ uintq nid)
neuronReservations[nid] = false;
}

MICROSOFT_QUANTUM_DECL void set_qneuron_angles(_In_ uintq nid, _In_ real1_f* angles)
#if FPPOW < 6
MICROSOFT_QUANTUM_DECL void set_qneuron_angles(_In_ uintq nid, _In_ float* angles)
#else
MICROSOFT_QUANTUM_DECL void set_qneuron_angles(_In_ uintq nid, _In_ double* angles)
#endif
{
NEURON_LOCK_GUARD_VOID(nid)
#if (FPPOW == 5) || (FPPOW == 6)
neuron->SetAngles(angles);
#else
const bitCapIntOcl inputPower = (bitCapIntOcl)neuron->GetInputPower();
std::unique_ptr<real1[]> _angles(new real1[inputPower]);
#if (FPPOW == 4)
std::copy(angles, angles + inputPower, _angles.get());
#else
std::transform(angles, angles + inputPower, _angles.get(), [](double d) { return (real1)d; });
#endif
neuron->SetAngles(_angles.get());
#endif
}

MICROSOFT_QUANTUM_DECL void get_qneuron_angles(_In_ uintq nid, _In_ real1_f* angles)
#if FPPOW < 6
MICROSOFT_QUANTUM_DECL void get_qneuron_angles(_In_ uintq nid, _In_ float* angles)
#else
MICROSOFT_QUANTUM_DECL void get_qneuron_angles(_In_ uintq nid, _In_ double* angles)
#endif
{
NEURON_LOCK_GUARD_VOID(nid)
#if (FPPOW == 5) || (FPPOW == 6)
Expand All @@ -2919,7 +2931,11 @@ MICROSOFT_QUANTUM_DECL void get_qneuron_angles(_In_ uintq nid, _In_ real1_f* ang
const bitCapIntOcl inputPower = (bitCapIntOcl)neuron->GetInputPower();
std::unique_ptr<real1[]> _angles(new real1[inputPower]);
neuron->GetAngles(_angles.get());
#if (FPPOW == 4)
std::copy(_angles.get(), _angles.get() + inputPower, angles);
#else
std::transform(_angles.get(), _angles.get() + inputPower, angles, [](real1 d) { return (double)d; });
#endif
#endif
}

Expand Down
22 changes: 11 additions & 11 deletions src/qunit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,22 @@ void QUnit::SetQuantumState(const complex* inputState)
shard.amp1 = inputState[1U];
shard.pauliBasis = PauliZ;
if (IS_AMP_0(shard.amp0 - shard.amp1)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp0 - shard.amp1)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp0 - shard.amp1)));
shard.pauliBasis = PauliX;
shard.amp0 = shard.amp0 / abs(shard.amp0);
shard.amp1 = ZERO_R1;
} else if (IS_AMP_0(shard.amp0 + shard.amp1)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp0 + shard.amp1)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp0 + shard.amp1)));
shard.pauliBasis = PauliX;
shard.amp1 = shard.amp0 / abs(shard.amp0);
shard.amp0 = ZERO_R1;
} else if (IS_AMP_0((I_CMPLX * inputState[0U]) - inputState[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm((I_CMPLX * inputState[0U]) - inputState[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm((I_CMPLX * inputState[0U]) - inputState[1U])));
shard.pauliBasis = PauliY;
shard.amp0 = shard.amp0 / abs(shard.amp0);
shard.amp1 = ZERO_R1;
} else if (IS_AMP_0((I_CMPLX * inputState[0U]) + inputState[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm((I_CMPLX * inputState[0U]) - inputState[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm((I_CMPLX * inputState[0U]) - inputState[1U])));
shard.pauliBasis = PauliY;
shard.amp1 = shard.amp0 / abs(shard.amp0);
shard.amp0 = ZERO_R1;
Expand Down Expand Up @@ -766,7 +766,7 @@ bool QUnit::TrySeparate(bitLenInt qubit)
SeparateBit(false, qubit);
ShardAI(qubit, azimuth, inclination);

logFidelity += log(clampProb(1.0 - oneMinR / 2));
logFidelity += (double)log(clampProb(1.0 - oneMinR / 2));

return true;
}
Expand Down Expand Up @@ -976,22 +976,22 @@ real1_f QUnit::ProbBase(bitLenInt qubit)
shard.unit->GetQuantumState(amps);

if (IS_AMP_0(amps[0U] - amps[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm(amps[0U] - amps[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(amps[0U] - amps[1U])));
shard.pauliBasis = PauliX;
amps[0U] = amps[0U] / abs(amps[0U]);
amps[1U] = ZERO_CMPLX;
} else if (IS_AMP_0(amps[0U] + amps[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm(amps[0U] + amps[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(amps[0U] + amps[1U])));
shard.pauliBasis = PauliX;
amps[1U] = amps[0U] / abs(amps[0U]);
amps[0U] = ZERO_CMPLX;
} else if (IS_AMP_0((I_CMPLX * amps[0U]) - amps[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm((I_CMPLX * amps[0U]) - amps[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm((I_CMPLX * amps[0U]) - amps[1U])));
shard.pauliBasis = PauliY;
amps[0U] = amps[0U] / abs(amps[0U]);
amps[1U] = ZERO_CMPLX;
} else if (IS_AMP_0((I_CMPLX * amps[0U]) + amps[1U])) {
logFidelity += log(clampProb(ONE_R1_F - norm((I_CMPLX * amps[0U]) + amps[1U])));
logFidelity += (double)log(clampProb(ONE_R1_F - norm((I_CMPLX * amps[0U]) + amps[1U])));
shard.pauliBasis = PauliY;
amps[1U] = amps[0U] / abs(amps[0U]);
amps[0U] = ZERO_CMPLX;
Expand Down Expand Up @@ -1019,10 +1019,10 @@ real1_f QUnit::ProbBase(bitLenInt qubit)
}

if (IS_NORM_0(shard.amp1)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp1)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp1)));
SeparateBit(false, qubit);
} else if (IS_NORM_0(shard.amp0)) {
logFidelity += log(clampProb(ONE_R1_F - norm(shard.amp0)));
logFidelity += (double)log(clampProb(ONE_R1_F - norm(shard.amp0)));
SeparateBit(true, qubit);
}

Expand Down
2 changes: 1 addition & 1 deletion test/benchmarks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4517,7 +4517,7 @@ real1_f diophantine_fidelity_correction(real1_f sigmoid, real1_f sdrp)
// Reverse variance normalization:
sigmoid = pow(sigmoid, 1 / (1 - sqrt(sdrp)));

if (std::isnan(sigmoid)) {
if (std::isnan((real1_s)sigmoid)) {
return 0;
}

Expand Down
3 changes: 2 additions & 1 deletion test/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3346,7 +3346,8 @@ TEST_CASE_METHOD(QInterfaceTestFixture, "test_getamplitude")
qftReg->H(0);
qftReg->T(0);
REQUIRE(abs(norm(qftReg->GetAmplitude(0x00)) - 0.5f) < 0.01);
REQUIRE(norm(qftReg->GetAmplitude(0x00) + complex(SQRT1_2_R1, SQRT1_2_R1) * I_CMPLX * qftReg->GetAmplitude(0x01)) < 0.01);
REQUIRE(norm(qftReg->GetAmplitude(0x00) + complex(SQRT1_2_R1, SQRT1_2_R1) * I_CMPLX * qftReg->GetAmplitude(0x01)) <
0.01);
}

TEST_CASE_METHOD(QInterfaceTestFixture, "test_getquantumstate")
Expand Down

0 comments on commit a96a133

Please sign in to comment.