diff --git a/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java b/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java index 494720d7..55741857 100644 --- a/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java +++ b/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java @@ -150,7 +150,7 @@ private static ComplexVector calculateRootsOfUnity(int n) { * {@code roots[s][k] = e^(pi*k*i/(2*roots.length))}, * i.e., they must cover the first quadrant. */ - private static void fft(ComplexVector a, ComplexVector[] roots) { + static void fftOriginal(ComplexVector a, ComplexVector[] roots) { int n = a.length; int logN = 31 - Integer.numberOfLeadingZeros(n); MutableComplex a0 = new MutableComplex(); @@ -211,6 +211,294 @@ private static void fft(ComplexVector a, ComplexVector[] roots) { } } + // do one final radix-2 step if there is an odd number of stages + if (s > 0) { + for (int i = 0; i < n; i += 2) { + // omega = 1 + + // a0 = a[i]; + // a1 = a[i + 1]; + // a[i] += a1; + // a[i + 1] = a0 - a1; + a.copyInto(i, a0); + a.copyInto(i + 1, a1); + a.add(i, a1); + a0.subtractInto(a1, a, i + 1); + } + } + } + + static void fft(ComplexVector a, ComplexVector[] roots) { + fftOptimizedLessVariables(a, roots); + } + + static void fftOptimizedLessVariables(ComplexVector a, ComplexVector[] roots) { + int n = a.length(); + int logN = 31 - Integer.numberOfLeadingZeros(n); + MutableComplex a0 = new MutableComplex(); + MutableComplex a1 = new MutableComplex(); + MutableComplex a2 = new MutableComplex(); + MutableComplex a3 = new MutableComplex(); + + // do two FFT stages at a time (radix-4) + MutableComplex omega1 = new MutableComplex(); + MutableComplex omega2 = new MutableComplex(); + MutableComplex e = new MutableComplex(); + 
MutableComplex h = new MutableComplex(); + + int s = logN; + for (; s >= 2; s -= 2) { + ComplexVector rootsS = roots[s - 2]; + int m = 1 << s; + for (int i = 0; i < n; i += m) { + final int m4 = m / 4; + for (int j = 0; j < m4; j++) { + omega1.set(rootsS, j); + // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(), + // but it is the same error we'd incur with radix-2, so we're not breaking the + // assumptions of the Percival paper. + omega1.squareInto(omega2); + + int index = i + j; + int idx0 = index; + index += m4; + int idx1 = index; + index += m4; + int idx2 = index; + index += m4; + int idx3 = index; + + // radix-4 butterfly: + // a[idx0] = (a[idx0] + a[idx1] + a[idx2] + a[idx3]) * w^0 + // a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i) * w^1 + // a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2] + a[idx3]*(-1)) * w^2 + // a[idx3] = (a[idx0] + a[idx1]*i + a[idx2]*(-1) + a[idx3]*(-i)) * w^3 + // where w = omega1^(-1) = conjugate(omega1) + // can be reordered to + // a[idx0] = ((a[idx0] + a[idx2]) + (a[idx1] + a[idx3])) * w^0 + // a[idx1] = ((a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1 + // a[idx2] = ((a[idx0] + a[idx2]) - (a[idx1] + a[idx3])) * w^2 + // a[idx3] = ((a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3 + // we define + // e = (a[idx0] + a[idx2]); f = (a[idx1] + a[idx3]); + // g = (a[idx0] - a[idx2]); h = (a[idx1] - a[idx3]); + a.addInto(idx0, a, idx2, e); + a.subtractInto(idx0, a, idx2, a3); + a.addInto(idx1, a, idx3, a2); + a.subtractInto(idx1, a, idx3, h); + + // original equation after substitution (a2 ~ f, a3 ~ g) + // a[idx0] = (e + f) * w^0 + // a[idx1] = (g - i*h) * w^1 + // a[idx2] = (e - f) * w^2 + // a[idx3] = (g + i*h) * w^3 + e.addInto(a2, a0); + + a3.subtractTimesIInto(h, a1); + a1.multiplyConjugate(omega1); + + e.subtractInto(a2, a2); + a2.multiplyConjugate(omega2); + + a3.addTimesIInto(h, a3); + a3.multiply(omega1); // Bernstein's trick: multiply by omega^(-1) instead of omega^3 + 
a0.copyInto(a, idx0); + a1.copyInto(a, idx1); + a2.copyInto(a, idx2); + a3.copyInto(a, idx3); + } + } + } + + // do one final radix-2 step if there is an odd number of stages + if (s > 0) { + for (int i = 0; i < n; i += 2) { + // omega = 1 + + // a0 = a[i]; + // a1 = a[i + IMAG]; + // a[i] += a1; + // a[i + IMAG] = a0 - a1; + a.copyInto(i, a0); + a.copyInto(i + ComplexVector.IMAG, a1); + a.add(i, a1); + a0.subtractInto(a1, a, i + 1); + } + } + } + + static void fftOptimizedLessVariablesOriginalIndex(ComplexVector a, ComplexVector[] roots) { + int n = a.length(); + int logN = 31 - Integer.numberOfLeadingZeros(n); + MutableComplex a0 = new MutableComplex(); + MutableComplex a1 = new MutableComplex(); + MutableComplex a2 = new MutableComplex(); + MutableComplex a3 = new MutableComplex(); + + // do two FFT stages at a time (radix-4) + MutableComplex omega1 = new MutableComplex(); + MutableComplex omega2 = new MutableComplex(); + MutableComplex e = new MutableComplex(); + MutableComplex h = new MutableComplex(); + + int s = logN; + for (; s >= 2; s -= 2) { + ComplexVector rootsS = roots[s - 2]; + int m = 1 << s; + for (int i = 0; i < n; i += m) { + for (int j = 0; j < m / 4; j++) { + omega1.set(rootsS, j); + // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(), + // but it is the same error we'd incur with radix-2, so we're not breaking the + // assumptions of the Percival paper. 
+ omega1.squareInto(omega2); + + int idx0 = i + j; + int idx1 = i + j + m / 4; + int idx2 = i + j + m / 2; + int idx3 = i + j + m * 3 / 4; + + // radix-4 butterfly: + // a[idx0] = (a[idx0] + a[idx1] + a[idx2] + a[idx3]) * w^0 + // a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i) * w^1 + // a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2] + a[idx3]*(-1)) * w^2 + // a[idx3] = (a[idx0] + a[idx1]*i + a[idx2]*(-1) + a[idx3]*(-i)) * w^3 + // where w = omega1^(-1) = conjugate(omega1) + // can be reordered to + // a[idx0] = ((a[idx0] + a[idx2]) + (a[idx1] + a[idx3])) * w^0 + // a[idx1] = ((a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1 + // a[idx2] = ((a[idx0] + a[idx2]) - (a[idx1] + a[idx3])) * w^2 + // a[idx3] = ((a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3 + // we define + // e = (a[idx0] + a[idx2]); f = (a[idx1] + a[idx3]); + // g = (a[idx0] - a[idx2]); h = (a[idx1] - a[idx3]); + a.addInto(idx0, a, idx2, e); + a.subtractInto(idx0, a, idx2, a3); + a.addInto(idx1, a, idx3, a2); + a.subtractInto(idx1, a, idx3, h); + + // original equation after substitution (a2 ~ f, a3 ~ g) + // a[idx0] = (e + f) * w^0 + // a[idx1] = (g - i*h) * w^1 + // a[idx2] = (e - f) * w^2 + // a[idx3] = (g + i*h) * w^3 + e.addInto(a2, a0); + + a3.subtractTimesIInto(h, a1); + a1.multiplyConjugate(omega1); + + e.subtractInto(a2, a2); + a2.multiplyConjugate(omega2); + + a3.addTimesIInto(h, a3); + a3.multiply(omega1); // Bernstein's trick: multiply by omega^(-1) instead of omega^3 + + a0.copyInto(a, idx0); + a1.copyInto(a, idx1); + a2.copyInto(a, idx2); + a3.copyInto(a, idx3); + } + } + } + + // do one final radix-2 step if there is an odd number of stages + if (s > 0) { + for (int i = 0; i < n; i += 2) { + // omega = 1 + + // a0 = a[i]; + // a1 = a[i + IMAG]; + // a[i] += a1; + // a[i + IMAG] = a0 - a1; + a.copyInto(i, a0); + a.copyInto(i + ComplexVector.IMAG, a1); + a.add(i, a1); + a0.subtractInto(a1, a, i + 1); + } + } + } + static void fftOptimized(ComplexVector a, ComplexVector[] 
roots) { + int n = a.length(); + int logN = 31 - Integer.numberOfLeadingZeros(n); + MutableComplex a0 = new MutableComplex(); + MutableComplex a1 = new MutableComplex(); + MutableComplex a2 = new MutableComplex(); + MutableComplex a3 = new MutableComplex(); + + // do two FFT stages at a time (radix-4) + MutableComplex omega1 = new MutableComplex(); + MutableComplex omega2 = new MutableComplex(); + MutableComplex e = new MutableComplex(); + MutableComplex f = new MutableComplex(); + MutableComplex g = new MutableComplex(); + MutableComplex h = new MutableComplex(); + + int s = logN; + for (; s >= 2; s -= 2) { + ComplexVector rootsS = roots[s - 2]; + int m = 1 << s; + for (int i = 0; i < n; i += m) { + final int m4 = m / 4; + for (int j = 0; j < m4; j++) { + omega1.set(rootsS, j); + // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(), + // but it is the same error we'd incur with radix-2, so we're not breaking the + // assumptions of the Percival paper. + omega1.squareInto(omega2); + + int index = i + j; + int idx0 = index; + index += m4; + int idx1 = index; + index += m4; + int idx2 = index; + index += m4; + int idx3 = index; + + // radix-4 butterfly: + // a[idx0] = (a[idx0] + a[idx1] + a[idx2] + a[idx3]) * w^0 + // a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i) * w^1 + // a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2] + a[idx3]*(-1)) * w^2 + // a[idx3] = (a[idx0] + a[idx1]*i + a[idx2]*(-1) + a[idx3]*(-i)) * w^3 + // where w = omega1^(-1) = conjugate(omega1) + // can be reordered to + // a[idx0] = ((a[idx0] + a[idx2]) + (a[idx1] + a[idx3])) * w^0 + // a[idx1] = ((a[idx0] - a[idx2]) - (i)*(a[idx1] - a[idx3])) * w^1 + // a[idx2] = ((a[idx0] + a[idx2]) - (a[idx1] + a[idx3])) * w^2 + // a[idx3] = ((a[idx0] - a[idx2]) + (i)*(a[idx1] - a[idx3])) * w^3 + // we define + // e = (a[idx0] + a[idx2]); f = (a[idx1] + a[idx3]); + // g = (a[idx0] - a[idx2]); h = (a[idx1] - a[idx3]); + a.addInto(idx0, a, idx2, e); + a.subtractInto(idx0, a, idx2, 
g); + a.addInto(idx1, a, idx3, f); + a.subtractInto(idx1, a, idx3, h); + + // original equation after substitution + // a[idx0] = (e + f) * w^0 + // a[idx1] = (g - i*h) * w^1 + // a[idx2] = (e - f) * w^2 + // a[idx3] = (g + i*h) * w^3 + e.addInto(f, a0); + + g.subtractTimesIInto(h, a1); + a1.multiplyConjugate(omega1); + + e.subtractInto(f, a2); + a2.multiplyConjugate(omega2); + + g.addTimesIInto(h, a3); + a3.multiply(omega1); // Bernstein's trick: multiply by omega^(-1) instead of omega^3 + + a0.copyInto(a, idx0); + a1.copyInto(a, idx1); + a2.copyInto(a, idx2); + a3.copyInto(a, idx3); + } + } + } + // do one final radix-2 step if there is an odd number of stages if (s > 0) { for (int i = 0; i < n; i += 2) { @@ -985,6 +1273,10 @@ void timesTwoToThe(int idxa, int n) { a[ri] = fastScalb(real, n); a[ii] = fastScalb(imag, n); } + + public int length() { + return length; + } } final static class MutableComplex { diff --git a/fastdoubleparser-dev/src/test/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/JmhFftMultiplierFft.java b/fastdoubleparser-dev/src/test/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/JmhFftMultiplierFft.java new file mode 100644 index 00000000..bd4a781f --- /dev/null +++ b/fastdoubleparser-dev/src/test/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/JmhFftMultiplierFft.java @@ -0,0 +1,107 @@ +/* + * @(#)JmhFftMultiplierFft.java + * Copyright © 2023 Werner Randelshofer, Switzerland. MIT License. + */ +package ch.randelshofer.fastdoubleparser; + +import ch.randelshofer.fastdoubleparser.FftMultiplier.ComplexVector; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.concurrent.TimeUnit; + +/** + * <pre>
+ * # JMH version: 1.36 + * # VM version: JDK 20.0.1, OpenJDK 64-Bit Server VM, 20.0.1+9-29 + * # Intel(R) Core(TM) i5-6200U CPU @ 2.30GHz + * + * Benchmark Mode Cnt Score Error Units + * original avgt 5 16084.425 ± 163.063 ns/op + * sums avgt 5 17349.120 ± 86.367 ns/op + * sums_less_variables avgt 5 14540.706 ± 145.691 ns/op + * sums_less_variables_original_index avgt 5 14694.344 ± 416.344 ns/op + * + * Process finished with exit code 0 + * </pre> + */ +@Fork(value = 1, jvmArgs = { + "-XX:+UnlockExperimentalVMOptions", "--add-modules", "jdk.incubator.vector" + , "--enable-preview" + //, "-Xmx4g" +}) +@Measurement(iterations = 5, time = 1) +@Warmup(iterations = 4, time = 1) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@BenchmarkMode(Mode.AverageTime) +@State(Scope.Benchmark) +public class JmhFftMultiplierFft { + + private static final int K = 10; + private static final int N = 1 << K; + + ComplexVector a = new ComplexVector(N); + ComplexVector[] roots = new ComplexVector[K]; + + @Setup(Level.Iteration) + public void setUp() { + for (int i = 0; i < N; i++) { + // just some arbitrary, mostly non-zero values + a.real(i, i); + a.imag(i, i * 200 >>> 8); + } + for (int i = 0; i < K; i++) { + roots[i] = new ComplexVector(N); + for (int j = 0; j < N; j++) { + roots[i].real(j, j + i + 4); + roots[i].imag(j, j * 250 >>> 7); + } + } + } + + // @formatter:off + @Benchmark + public void original(Blackhole blackhole) { run(FftAlgorithm.ORIGINAL, blackhole); } + @Benchmark + public void sums(Blackhole blackhole) { run(FftAlgorithm.OPTIMIZED, blackhole); } + @Benchmark + public void sums_less_variables(Blackhole blackhole) { run(FftAlgorithm.OPTIMIZED_LESS_VARIABLES, blackhole); } + @Benchmark + public void sums_less_variables_original_index(Blackhole blackhole) { run(FftAlgorithm.OPTIMIZED_VARIABLE_REDUCED_ORIGINAL_INDEX, blackhole); } + // @formatter:on + + private enum FftAlgorithm { + ORIGINAL { + @Override + void fft(ComplexVector a, ComplexVector[] roots) { + FftMultiplier.fftOriginal(a, roots); + } + 
}, + OPTIMIZED { + @Override + void fft(ComplexVector a, ComplexVector[] roots) { + FftMultiplier.fftOptimized(a, roots); + } + }, + OPTIMIZED_LESS_VARIABLES { + @Override + void fft(ComplexVector a, ComplexVector[] roots) { + FftMultiplier.fftOptimizedLessVariables(a, roots); + } + }, + OPTIMIZED_VARIABLE_REDUCED_ORIGINAL_INDEX { + @Override + void fft(ComplexVector a, ComplexVector[] roots) { + FftMultiplier.fftOptimizedLessVariablesOriginalIndex(a, roots); + } + }; + + @SuppressWarnings("SameParameterValue") + abstract void fft(ComplexVector a, ComplexVector[] roots); + } + + private void run(FftAlgorithm algorithm, Blackhole blackhole) { + algorithm.fft(a, roots); + blackhole.consume(a); + } +}