Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make method fft() about 10% faster #76

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ private static ComplexVector calculateRootsOfUnity(int n) {
* {@code roots[s][k] = e^(pi*k*i/(2*roots.length))},
* i.e., they must cover the first quadrant.
*/
private static void fft(ComplexVector a, ComplexVector[] roots) {
static void fftOriginal(ComplexVector a, ComplexVector[] roots) {
int n = a.length;
int logN = 31 - Integer.numberOfLeadingZeros(n);
MutableComplex a0 = new MutableComplex();
Expand Down Expand Up @@ -211,6 +211,294 @@ private static void fft(ComplexVector a, ComplexVector[] roots) {
}
}

// do one final radix-2 step if there is an odd number of stages
if (s > 0) {
for (int i = 0; i < n; i += 2) {
// omega = 1

// a0 = a[i];
// a1 = a[i + 1];
// a[i] += a1;
// a[i + 1] = a0 - a1;
a.copyInto(i, a0);
a.copyInto(i + 1, a1);
a.add(i, a1);
a0.subtractInto(a1, a, i + 1);
}
}
}

/**
 * Entry point for the in-place FFT; forwards both arguments unchanged to
 * {@code fftOptimizedLessVariables}, which does all of the work.
 *
 * @param a     the vector handed to the selected implementation, which overwrites it
 * @param roots the roots-of-unity tables handed to the selected implementation
 */
static void fft(ComplexVector a, ComplexVector[] roots) {
fftOptimizedLessVariables(a, roots);
}

/**
 * In-place FFT of {@code a} that processes two stages at a time (radix-4) and finishes
 * with a single radix-2 stage when the number of stages is odd.
 * <p>
 * Compared to a straightforward radix-4 butterfly this variant keeps only two extra
 * temporaries ({@code e} and {@code h}); the roles of {@code f} and {@code g} are played
 * by {@code a2} and {@code a3}, which are then overwritten with the final results.
 * <p>
 * NOTE(review): the loops assume {@code a.length()} is a power of two; this is not
 * checked here - confirm against callers.
 *
 * @param a     the vector to transform; it is overwritten with the result
 * @param roots roots-of-unity tables; stage {@code s} reads {@code roots[s - 2]} at
 *              indices {@code 0 .. 2^s/4 - 1}. Presumably the same layout that
 *              {@code fftOriginal} expects - TODO confirm.
 */
static void fftOptimizedLessVariables(ComplexVector a, ComplexVector[] roots) {
    int n = a.length();
    int logN = 31 - Integer.numberOfLeadingZeros(n);
    MutableComplex a0 = new MutableComplex();
    MutableComplex a1 = new MutableComplex();
    MutableComplex a2 = new MutableComplex();
    MutableComplex a3 = new MutableComplex();

    // do two FFT stages at a time (radix-4)
    MutableComplex omega1 = new MutableComplex();
    MutableComplex omega2 = new MutableComplex();
    MutableComplex e = new MutableComplex();
    MutableComplex h = new MutableComplex();

    int s = logN;
    for (; s >= 2; s -= 2) {
        ComplexVector rootsS = roots[s - 2];
        int m = 1 << s;
        // quarter of the block size; constant for the whole stage
        final int m4 = m / 4;
        for (int i = 0; i < n; i += m) {
            for (int j = 0; j < m4; j++) {
                omega1.set(rootsS, j);
                // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(),
                // but it is the same error we'd incur with radix-2, so we're not breaking the
                // assumptions of the Percival paper.
                omega1.squareInto(omega2);

                // the four butterfly inputs are spaced m/4 apart
                int idx0 = i + j;
                int idx1 = idx0 + m4;
                int idx2 = idx1 + m4;
                int idx3 = idx2 + m4;

                // radix-4 butterfly:
                //   a[idx0] = (a[idx0] + a[idx1]      + a[idx2]      + a[idx3])      * w^0
                //   a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i)    * w^1
                //   a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2]      + a[idx3]*(-1)) * w^2
                //   a[idx3] = (a[idx0] + a[idx1]*i    + a[idx2]*(-1) + a[idx3]*(-i)) * w^3
                // where w = omega1^(-1) = conjugate(omega1)
                // can be reordered to
                //   a[idx0] = ((a[idx0] + a[idx2]) +     (a[idx1] + a[idx3])) * w^0
                //   a[idx1] = ((a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1
                //   a[idx2] = ((a[idx0] + a[idx2]) -     (a[idx1] + a[idx3])) * w^2
                //   a[idx3] = ((a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3
                // we define
                //   e = (a[idx0] + a[idx2]);  f = (a[idx1] + a[idx3]);
                //   g = (a[idx0] - a[idx2]);  h = (a[idx1] - a[idx3]);
                // and store f in a2 and g in a3 to save two temporaries
                a.addInto(idx0, a, idx2, e);
                a.subtractInto(idx0, a, idx2, a3);
                a.addInto(idx1, a, idx3, a2);
                a.subtractInto(idx1, a, idx3, h);

                // original equation after substitution (a2 ~ f, a3 ~ g)
                //   a[idx0] = (e + f)   * w^0
                //   a[idx1] = (g - i*h) * w^1
                //   a[idx2] = (e - f)   * w^2
                //   a[idx3] = (g + i*h) * w^3
                e.addInto(a2, a0);

                // must run before a3 (= g) is overwritten below
                a3.subtractTimesIInto(h, a1);
                a1.multiplyConjugate(omega1);

                // a2 is both operand (f) and destination; f is not needed afterwards
                e.subtractInto(a2, a2);
                a2.multiplyConjugate(omega2);

                // a3 is both receiver (g) and destination; g is not needed afterwards
                a3.addTimesIInto(h, a3);
                a3.multiply(omega1);   // Bernstein's trick: multiply by omega^(-1) instead of omega^3

                a0.copyInto(a, idx0);
                a1.copyInto(a, idx1);
                a2.copyInto(a, idx2);
                a3.copyInto(a, idx3);
            }
        }
    }

    // do one final radix-2 step if there is an odd number of stages
    if (s > 0) {
        for (int i = 0; i < n; i += 2) {
            // omega = 1

            // a0 = a[i];
            // a1 = a[i + 1];
            // a[i] += a1;
            // a[i + 1] = a0 - a1;
            a.copyInto(i, a0);
            // the partner element is the next one (i + 1), the same slot that is written
            // two lines below; ComplexVector.IMAG is not an element offset
            a.copyInto(i + 1, a1);
            a.add(i, a1);
            a0.subtractInto(a1, a, i + 1);
        }
    }
}

/**
 * In-place radix-4 FFT of {@code a} with a trailing radix-2 stage when the number of
 * stages is odd. The arithmetic is the same as in {@code fftOptimizedLessVariables}
 * (two extra temporaries, {@code a2}/{@code a3} doubling as {@code f}/{@code g});
 * the only difference is that the four butterfly indices are each computed directly
 * from {@code i}, {@code j} and {@code m} instead of by repeated addition.
 * <p>
 * NOTE(review): the loops assume {@code a.length()} is a power of two; this is not
 * checked here - confirm against callers.
 *
 * @param a     the vector to transform; it is overwritten with the result
 * @param roots roots-of-unity tables; stage {@code s} reads {@code roots[s - 2]} at
 *              indices {@code 0 .. 2^s/4 - 1}. Presumably the same layout that
 *              {@code fftOriginal} expects - TODO confirm.
 */
static void fftOptimizedLessVariablesOriginalIndex(ComplexVector a, ComplexVector[] roots) {
    int n = a.length();
    int logN = 31 - Integer.numberOfLeadingZeros(n);
    MutableComplex a0 = new MutableComplex();
    MutableComplex a1 = new MutableComplex();
    MutableComplex a2 = new MutableComplex();
    MutableComplex a3 = new MutableComplex();

    // do two FFT stages at a time (radix-4)
    MutableComplex omega1 = new MutableComplex();
    MutableComplex omega2 = new MutableComplex();
    MutableComplex e = new MutableComplex();
    MutableComplex h = new MutableComplex();

    int s = logN;
    for (; s >= 2; s -= 2) {
        ComplexVector rootsS = roots[s - 2];
        int m = 1 << s;
        for (int i = 0; i < n; i += m) {
            for (int j = 0; j < m / 4; j++) {
                omega1.set(rootsS, j);
                // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(),
                // but it is the same error we'd incur with radix-2, so we're not breaking the
                // assumptions of the Percival paper.
                omega1.squareInto(omega2);

                // the four butterfly inputs are spaced m/4 apart
                int idx0 = i + j;
                int idx1 = i + j + m / 4;
                int idx2 = i + j + m / 2;
                int idx3 = i + j + m * 3 / 4;

                // radix-4 butterfly:
                //   a[idx0] = (a[idx0] + a[idx1]      + a[idx2]      + a[idx3])      * w^0
                //   a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i)    * w^1
                //   a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2]      + a[idx3]*(-1)) * w^2
                //   a[idx3] = (a[idx0] + a[idx1]*i    + a[idx2]*(-1) + a[idx3]*(-i)) * w^3
                // where w = omega1^(-1) = conjugate(omega1)
                // can be reordered to
                //   a[idx0] = ((a[idx0] + a[idx2]) +     (a[idx1] + a[idx3])) * w^0
                //   a[idx1] = ((a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1
                //   a[idx2] = ((a[idx0] + a[idx2]) -     (a[idx1] + a[idx3])) * w^2
                //   a[idx3] = ((a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3
                // we define
                //   e = (a[idx0] + a[idx2]);  f = (a[idx1] + a[idx3]);
                //   g = (a[idx0] - a[idx2]);  h = (a[idx1] - a[idx3]);
                // and store f in a2 and g in a3 to save two temporaries
                a.addInto(idx0, a, idx2, e);
                a.subtractInto(idx0, a, idx2, a3);
                a.addInto(idx1, a, idx3, a2);
                a.subtractInto(idx1, a, idx3, h);

                // original equation after substitution (a2 ~ f, a3 ~ g)
                //   a[idx0] = (e + f)   * w^0
                //   a[idx1] = (g - i*h) * w^1
                //   a[idx2] = (e - f)   * w^2
                //   a[idx3] = (g + i*h) * w^3
                e.addInto(a2, a0);

                // must run before a3 (= g) is overwritten below
                a3.subtractTimesIInto(h, a1);
                a1.multiplyConjugate(omega1);

                // a2 is both operand (f) and destination; f is not needed afterwards
                e.subtractInto(a2, a2);
                a2.multiplyConjugate(omega2);

                // a3 is both receiver (g) and destination; g is not needed afterwards
                a3.addTimesIInto(h, a3);
                a3.multiply(omega1);   // Bernstein's trick: multiply by omega^(-1) instead of omega^3

                a0.copyInto(a, idx0);
                a1.copyInto(a, idx1);
                a2.copyInto(a, idx2);
                a3.copyInto(a, idx3);
            }
        }
    }

    // do one final radix-2 step if there is an odd number of stages
    if (s > 0) {
        for (int i = 0; i < n; i += 2) {
            // omega = 1

            // a0 = a[i];
            // a1 = a[i + 1];
            // a[i] += a1;
            // a[i + 1] = a0 - a1;
            a.copyInto(i, a0);
            // the partner element is the next one (i + 1), the same slot that is written
            // two lines below; ComplexVector.IMAG is not an element offset
            a.copyInto(i + 1, a1);
            a.add(i, a1);
            a0.subtractInto(a1, a, i + 1);
        }
    }
}
static void fftOptimized(ComplexVector a, ComplexVector[] roots) {
int n = a.length();
int logN = 31 - Integer.numberOfLeadingZeros(n);
MutableComplex a0 = new MutableComplex();
MutableComplex a1 = new MutableComplex();
MutableComplex a2 = new MutableComplex();
MutableComplex a3 = new MutableComplex();

// do two FFT stages at a time (radix-4)
MutableComplex omega1 = new MutableComplex();
MutableComplex omega2 = new MutableComplex();
MutableComplex e = new MutableComplex();
MutableComplex f = new MutableComplex();
MutableComplex g = new MutableComplex();
MutableComplex h = new MutableComplex();

int s = logN;
for (; s >= 2; s -= 2) {
ComplexVector rootsS = roots[s - 2];
int m = 1 << s;
for (int i = 0; i < n; i += m) {
final int m4 = m / 4;
for (int j = 0; j < m4; j++) {
omega1.set(rootsS, j);
// computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(),
// but it is the same error we'd incur with radix-2, so we're not breaking the
// assumptions of the Percival paper.
omega1.squareInto(omega2);

int index = i + j;
int idx0 = index;
index += m4;
int idx1 = index;
index += m4;
int idx2 = index;
index += m4;
int idx3 = index;

// radix-4 butterfly:
// a[idx0] = (a[idx0] + a[idx1] + a[idx2] + a[idx3]) * w^0
// a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i) * w^1
// a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2] + a[idx3]*(-1)) * w^2
// a[idx3] = (a[idx0] + a[idx1]*i + a[idx2]*(-1) + a[idx3]*(-i)) * w^3
// where w = omega1^(-1) = conjugate(omega1)
// can be reordered to
// a[idx0] = (a[idx0] + a[idx2]) + (a[idx1] + a[idx3])) * w^0
// a[idx1] = (a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1
// a[idx2] = (a[idx0] + a[idx2]) - (a[idx1] + a[idx3])) * w^2
// a[idx3] = (a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3
// we define
// e = (a[idx0] + a[idx2]); f = (a[idx1] + a[idx3]);
// g = (a[idx0] - a[idx2]); h = (a[idx1] - a[idx3]);
a.addInto(idx0, a, idx2, e);
a.subtractInto(idx0, a, idx2, g);
a.addInto(idx1, a, idx3, f);
a.subtractInto(idx1, a, idx3, h);

// original equation after substitution
// a[idx0] = (e + f) * w^0
// a[idx1] = (g - h) * w^1
// a[idx2] = (e - f) * w^2
// a[idx3] = (g + h) * w^3
e.addInto(f, a0);

g.subtractTimesIInto(h, a1);
a1.multiplyConjugate(omega1);

e.subtractInto(f, a2);
a2.multiplyConjugate(omega2);

g.addTimesIInto(h, a3);
a3.multiply(omega1); // Bernstein's trick: multiply by omega^(-1) instead of omega^3

a0.copyInto(a, idx0);
a1.copyInto(a, idx1);
a2.copyInto(a, idx2);
a3.copyInto(a, idx3);
}
}
}

// do one final radix-2 step if there is an odd number of stages
if (s > 0) {
for (int i = 0; i < n; i += 2) {
Expand Down Expand Up @@ -985,6 +1273,10 @@ void timesTwoToThe(int idxa, int n) {
a[ri] = fastScalb(real, n);
a[ii] = fastScalb(imag, n);
}

/**
 * Accessor for this vector's {@code length} field.
 * NOTE(review): the FFT routines use the returned value as the exclusive upper bound
 * for element indices, so it is presumably the number of complex elements - confirm.
 *
 * @return the current value of {@code length}
 */
public int length() {
    return this.length;
}
}

final static class MutableComplex {
Expand Down
Loading