cleanup
no functional change
xu-shawn committed Jan 19, 2025
1 parent 8e3e22b commit bab6dfc
Showing 1 changed file with 72 additions and 55 deletions: src/nnue/nnue_feature_transformer.h
@@ -26,6 +26,7 @@
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <type_traits>
#include <utility>

#include "../position.h"
@@ -145,14 +146,50 @@ using psqt_vec_t = int32x4_t;

#endif

// Returns the inverse of a permutation
template<std::size_t Len>
constexpr std::array<std::size_t, Len> inverse_order(const std::array<std::size_t, Len>& order) {
std::array<std::size_t, Len> reversed{};
for (std::size_t i = 0; i < order.size(); i++)
reversed[order[i]] = i;
return reversed;
}
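// Illustrative sketch (editor's example, not part of this commit): for
// order = {2, 0, 1}, element 0 is sent to index 2, so the inverse must
// send index 2 back to 0:
//
//   constexpr std::array<std::size_t, 3> order{2, 0, 1};
//   constexpr auto                       inv = inverse_order(order);
//   static_assert(inv[0] == 1 && inv[1] == 2 && inv[2] == 0);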

// Divide a byte region of size TotalSize into blocks of size
// BlockSize, and permute the blocks according to a given order
template<std::size_t TotalSize, std::size_t BlockSize, std::size_t OrderSize>
void permute(void* const data, const std::array<std::size_t, OrderSize>& order) {
static_assert(TotalSize % (BlockSize * OrderSize) == 0,
"BlockSize * OrderSize must perfectly divide TotalSize");
constexpr std::size_t ProcessChunkSize = BlockSize * OrderSize;

std::array<std::byte, ProcessChunkSize> buffer{};

std::byte* const bytes = reinterpret_cast<std::byte*>(data);

for (std::size_t i = 0; i < TotalSize; i += ProcessChunkSize)
{
std::byte* const values = &bytes[i];

for (std::size_t j = 0; j < OrderSize; j++)
{
auto* const buffer_chunk = &buffer[j * BlockSize];
auto* const value_chunk = &values[order[j] * BlockSize];

std::copy(value_chunk, value_chunk + BlockSize, buffer_chunk);
}

std::copy(std::begin(buffer), std::end(buffer), values);
}
}
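// Usage sketch (editor's example, not part of this commit): with two
// 4-byte blocks and order {1, 0}, block j of the result is block
// order[j] of the input, so the two halves of an 8-byte buffer swap:
//
//   std::array<std::byte, 8> buf{};
//   permute<8, 4>(buf.data(), std::array<std::size_t, 2>{1, 0});
//   // bytes 4..7 now precede bytes 0..3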

// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions>
class SIMDTiling {
#ifdef VECTOR
// We use __m* types as template arguments, which causes GCC to emit warnings
// about losing some attribute information. This is irrelevant to us as we
// only take their size, so the following pragmas are harmless.
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
@@ -223,62 +260,42 @@ class FeatureTransformer {
// Size of forward propagation buffer
static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType);

// Store the order in which the 128-bit blocks of a 1024-bit data
// region must be permuted so that calling packus on adjacent vectors
// of 16-bit integers loaded from the data yields the pre-permutation order
static constexpr auto PackusEpi16Order = []() -> std::array<std::size_t, 8> {
#if defined(USE_AVX512)
// _mm512_packus_epi16 after permutation:
// | 0 | 2 | 4 | 6 | // Vector 0
// | 1 | 3 | 5 | 7 | // Vector 1
// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | // Packed Result
return {0, 2, 4, 6, 1, 3, 5, 7};
#elif defined(USE_AVX2)
// _mm256_packus_epi16 after permutation:
// | 0 | 2 | | 4 | 6 | // Vector 0, 2
// | 1 | 3 | | 5 | 7 | // Vector 1, 3
// | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | // Packed Result
return {0, 2, 1, 3, 4, 6, 5, 7};
#else
return {0, 1, 2, 3, 4, 5, 6, 7};
#endif
}();

static constexpr auto InversePackusEpi16Order = inverse_order(PackusEpi16Order);
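// Editor's note on why this ordering works: on AVX2,
// _mm256_packus_epi16(a, b) packs within each 128-bit lane, yielding
// [a.lo, b.lo, a.hi, b.hi]. If vector 0 holds blocks {0, 2} and
// vector 1 holds blocks {1, 3}, the pack therefore emits {0, 1, 2, 3},
// the pre-permutation order, and no cross-lane shuffle is needed on
// the packed result.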

// Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value() {
return FeatureSet::HashValue ^ (OutputDimensions * 2);
}

static constexpr void order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // _mm512_packs_epi16 ordering
uint64_t tmp0 = v[2], tmp1 = v[3];
v[2] = v[8], v[3] = v[9];
v[8] = v[4], v[9] = v[5];
v[4] = tmp0, v[5] = tmp1;
tmp0 = v[6], tmp1 = v[7];
v[6] = v[10], v[7] = v[11];
v[10] = v[12], v[11] = v[13];
v[12] = tmp0, v[13] = tmp1;
#elif defined(USE_AVX2) // _mm256_packs_epi16 ordering
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
#endif
}

static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) {
#if defined(USE_AVX512) // Inverse _mm512_packs_epi16 ordering
uint64_t tmp0 = v[2], tmp1 = v[3];
v[2] = v[4], v[3] = v[5];
v[4] = v[8], v[5] = v[9];
v[8] = tmp0, v[9] = tmp1;
tmp0 = v[6], tmp1 = v[7];
v[6] = v[12], v[7] = v[13];
v[12] = v[10], v[13] = v[11];
v[10] = tmp0, v[11] = tmp1;
#elif defined(USE_AVX2) // Inverse _mm256_packs_epi16 ordering
std::swap(v[2], v[4]);
std::swap(v[3], v[5]);
#endif
}
void permute_weights() {
permute<sizeof(biases), 16>(biases, PackusEpi16Order);
permute<sizeof(weights), 16>(weights, PackusEpi16Order);
}

void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) {
#if defined(USE_AVX2)
#if defined(USE_AVX512)
constexpr IndexType di = 16;
#else
constexpr IndexType di = 8;
#endif
uint64_t* b = reinterpret_cast<uint64_t*>(&biases[0]);
for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di)
order_fn(&b[i]);

for (IndexType j = 0; j < InputDimensions; ++j)
{
uint64_t* w = reinterpret_cast<uint64_t*>(&weights[j * HalfDimensions]);
for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t);
i += di)
order_fn(&w[i]);
}
#endif
}
void unpermute_weights() {
permute<sizeof(biases), 16>(biases, InversePackusEpi16Order);
permute<sizeof(weights), 16>(weights, InversePackusEpi16Order);
}
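// Editor's note: unpermute_weights() exactly undoes permute_weights(),
// since InversePackusEpi16Order is inverse_order(PackusEpi16Order);
// write_parameters() below relies on this to emit the weights in their
// original file order.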

inline void scale_weights(bool read) {
@@ -300,22 +317,22 @@ class FeatureTransformer {
read_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
permute_weights();
scale_weights(true);
return !stream.fail();
}

// Write network parameters
bool write_parameters(std::ostream& stream) {

permute_weights(order_packs);
unpermute_weights();
scale_weights(false);

write_leb_128<BiasType>(stream, biases, HalfDimensions);
write_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
permute_weights();
scale_weights(true);
return !stream.fail();
}
