Skip to content

Commit

Permalink
Merge pull request #14 from stevenewald/performance
Browse files (browse the repository at this point in the history)
Performance
  • Loading branch information
stevenewald authored Jan 11, 2025
2 parents b33374b + 7cb488d commit a06f3da
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ check_cxx_source_compiles("
return 0;
}" HAS_NEON)

if (HAS_ALL_AVX512)
if (HAS_AVX512)
message(STATUS "AVX-512 is supported by the compiler.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512dq -mavx512vl -mavx512bf16")
target_sources(fractal-generator_lib PRIVATE source/mandelbrot/equations_simd.cpp)
Expand Down
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ tasks:
dir: .
cmds:
- task: build
- ./build/benchmark/fractal_benchmarks --benchmark_repetitions=5 --benchmark_min_warmup_time=0.5 --benchmark_report_aggregates_only=true
- sudo chrt --fifo 99 taskset 3 ./build/benchmark/fractal_benchmarks --benchmark_repetitions=5 --benchmark_min_warmup_time=0.5 --benchmark_report_aggregates_only=true

fmt:
dir: .
Expand Down
2 changes: 1 addition & 1 deletion benchmark/source/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static void BM_GenerateMandelbrotSimd(benchmark::State& state)
for (auto _ : state) {
for (auto it = display.begin(); it != display.end(); it += 8) {
benchmark::DoNotOptimize(
compute_iterations(start, t.to_complex_projections(*it), 100)
compute_iterations(start, t.to_complex_projections(*it), 5000)
);
}
prox += display.size();
Expand Down
41 changes: 18 additions & 23 deletions source/mandelbrot/equations_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,20 @@ std::array<iteration_count, 8> compute_iterations(
__m512d input_vec_constant_imags = _mm512_load_pd(const_imags.data());
__m512d input_vec_constant_reals = _mm512_load_pd(const_reals.data());

__m128i solved_its_vec = _mm_set1_epi16(0);
__m256i solved_its_vec = _mm256_set1_epi32(0);
const __m512d squared_divergence_vec = _mm512_set1_pd(SQUARED_DIVERGENCE);
__mmask8 active_mask = 0xFF;

for (iteration_count iterations = 0; iterations < max_iters; iterations++) {
// load current values
__m512d x = input_vec_real;
__m512d y = input_vec_imag;

// compute squares and product
__m512d x_squared = _mm512_mul_pd(x, x);
__m512d y_squared = _mm512_mul_pd(y, y);
__m512d xy = _mm512_mul_pd(x, y);
__m512d x_squared = _mm512_mul_pd(input_vec_real, input_vec_real);
__m512d y_squared = _mm512_mul_pd(input_vec_imag, input_vec_imag);
__m512d xy = _mm512_mul_pd(input_vec_real, input_vec_imag);

// update real part: input_vec_real = x_squared - y_squared + constant_reals
__m512d temp_real = _mm512_sub_pd(x_squared, y_squared);
input_vec_real = _mm512_add_pd(temp_real, input_vec_constant_reals);
input_vec_real = _mm512_add_pd(
_mm512_sub_pd(x_squared, y_squared), input_vec_constant_reals
);

// update imaginary part: input_vec_imag = 2 * xy + constant_imags
input_vec_imag =
Expand All @@ -49,30 +46,28 @@ std::array<iteration_count, 8> compute_iterations(
__m512d squared_norms_vec = _mm512_add_pd(x_squared, y_squared);

// determine which elements have diverged
__mmask8 solved_mask =
_mm512_cmp_pd_mask(squared_norms_vec, squared_divergence_vec, _CMP_GT_OS);
active_mask =
_mm512_cmp_pd_mask(squared_norms_vec, squared_divergence_vec, _CMP_LE_OS);

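// record the current iteration for lanes that are still active (not yet diverged);
// lanes that have diverged keep the last iteration count stored while they were active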
solved_its_vec = _mm_mask_blend_epi16(
solved_mask, solved_its_vec,
_mm_set1_epi16(static_cast<int16_t>(iterations))
solved_its_vec = _mm256_mask_blend_epi32(
active_mask, solved_its_vec, _mm256_set1_epi32(static_cast<int>(iterations))
);

// update active mask to skip computations for diverged elements
active_mask = _kandn_mask8(solved_mask, active_mask);

// break if all elements have diverged
if (active_mask == 0) [[unlikely]]
break;
}

__mmask8 mask = _mm_cmpeq_epi16_mask(solved_its_vec, _mm_set1_epi16(0));
solved_its_vec = _mm_mask_mov_epi16(
solved_its_vec, mask, _mm_set1_epi16(static_cast<int16_t>(max_iters))
__mmask8 mask = _mm256_cmpeq_epi32_mask(
solved_its_vec, _mm256_set1_epi32(static_cast<int>(max_iters) - 1)
);
solved_its_vec = _mm256_mask_mov_epi32(
solved_its_vec, mask, _mm256_set1_epi32(static_cast<int16_t>(max_iters))
);

alignas(16) std::array<iteration_count, 8> ret{};
_mm_storeu_epi16(ret.data(), solved_its_vec);
alignas(32) std::array<iteration_count, 8> ret{};
_mm256_storeu_epi32(ret.data(), solved_its_vec);

return ret;
}
Expand Down
2 changes: 1 addition & 1 deletion source/units/units.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <cstdint>

namespace fractal {
using iteration_count = std::uint16_t;
using iteration_count = std::uint32_t;

using complex_underlying = double;

Expand Down

0 comments on commit a06f3da

Please sign in to comment.