[tt-train] Add bf16 support #17821

Open

wants to merge 4 commits into main
Changes from 2 commits
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/core/xtensor_utils.hpp
@@ -5,11 +5,11 @@
#pragma once

#include <core/ttnn_all_includes.hpp>
#include <core/xtensor_utils.hpp>
#include <span>
#include <ttnn/tensor/shape/shape.hpp>
#include <ttnn/tensor/xtensor/conversion_utils.hpp>
Contributor: please remove ttnn headers (not in scope of this PR, but helpful)

#include <ttnn/tensor/xtensor/partition.hpp>

// TODO: decide if we want to use xarray everywhere or xtensor is ok
/*
Difference between xtensor and xarray:
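As a side note on the comment above, a minimal sketch of the distinction (standard xtensor semantics, not part of this diff):

#include <xtensor/xarray.hpp>
#include <xtensor/xtensor.hpp>

void xarray_vs_xtensor_example() {
    // xt::xarray<T>: the rank (number of dimensions) is dynamic, decided at runtime.
    xt::xarray<float> a = {{1.0f, 2.0f}, {3.0f, 4.0f}};

    // xt::xtensor<T, N>: the rank is fixed at compile time (here N = 2),
    // so the shape is stored without dynamic allocation.
    xt::xtensor<float, 2> t = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    (void)a;
    (void)t;
}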
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/math/bf16.cpp
@@ -0,0 +1,7 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
Contributor: 2025

Contributor (author): 🐒

//
// SPDX-License-Identifier: Apache-2.0

#include "bf16.hpp"

namespace ttml::math {}
126 changes: 126 additions & 0 deletions tt-train/sources/ttml/math/bf16.hpp
@@ -0,0 +1,126 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <sys/types.h>
Contributor: do we really need sys/types?


#include <array>
#include <bit>
#include <cmath>
#include <cstdint>

namespace ttml::math {

class bfloat16 {
public:
uint16_t m_raw_value = 0;

bfloat16() = default;

constexpr inline explicit bfloat16(float f) noexcept {
m_raw_value = float_to_bfloat16(f);
}

constexpr inline explicit bfloat16(double d) noexcept {
m_raw_value = float_to_bfloat16(static_cast<float>(d));
}

constexpr inline operator float() const noexcept {
return bfloat16_to_float(m_raw_value);
}

constexpr inline operator double() const noexcept {
return static_cast<double>(bfloat16_to_float(m_raw_value));
}
Contributor: can we move a few functions to cpp?

Contributor (author): constexpr and I want all of them to be inlined.
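A minimal, hypothetical illustration of the compile-time use this refers to (not part of this diff; it assumes a toolchain where std::fpclassify is usable in constant expressions, i.e. C++23 or later):

// Sketch only: evaluates at compile time when the whole constexpr path,
// including std::fpclassify, is constant-evaluable (C++23).
static_assert(static_cast<float>(ttml::math::bfloat16(1.5f) + ttml::math::bfloat16(2.5f)) == 4.0f);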


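// Arithmetic is performed in float: both operands are widened, the operation
// runs in float, and the result is rounded back to bfloat16 by the converting constructor.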
constexpr inline bfloat16 operator+(const bfloat16 &rhs) const noexcept {
float lhs_f = static_cast<float>(*this);
float rhs_f = static_cast<float>(rhs);
return bfloat16(lhs_f + rhs_f);
}

constexpr inline bfloat16 operator-(const bfloat16 &rhs) const noexcept {
float lhs_f = static_cast<float>(*this);
float rhs_f = static_cast<float>(rhs);
return bfloat16(lhs_f - rhs_f);
}

constexpr inline bfloat16 operator*(const bfloat16 &rhs) const noexcept {
float lhs_f = static_cast<float>(*this);
float rhs_f = static_cast<float>(rhs);
Contributor: discussed offline that we will use float operations, as we don't need bit-perfect comparison

return bfloat16(lhs_f * rhs_f);
}

constexpr inline bfloat16 operator/(const bfloat16 &rhs) const noexcept {
float lhs_f = static_cast<float>(*this);
float rhs_f = static_cast<float>(rhs);
return bfloat16(lhs_f / rhs_f);
}

constexpr inline bfloat16 &operator+=(const bfloat16 &rhs) noexcept {
*this = *this + rhs;
return *this;
}
constexpr inline bfloat16 &operator-=(const bfloat16 &rhs) noexcept {
*this = *this - rhs;
return *this;
}
constexpr inline bfloat16 &operator*=(const bfloat16 &rhs) noexcept {
*this = *this * rhs;
return *this;
}
constexpr inline bfloat16 &operator/=(const bfloat16 &rhs) noexcept {
*this = *this / rhs;
return *this;
}

constexpr inline bool operator==(const bfloat16 &rhs) const noexcept {
return static_cast<float>(*this) == static_cast<float>(rhs);
}
constexpr inline bool operator!=(const bfloat16 &rhs) const noexcept {
return !(*this == rhs);
}
constexpr inline bool operator<(const bfloat16 &rhs) const noexcept {
return static_cast<float>(*this) < static_cast<float>(rhs);
}
constexpr inline bool operator>(const bfloat16 &rhs) const noexcept {
return rhs < *this;
}
constexpr inline bool operator<=(const bfloat16 &rhs) const noexcept {
return !(*this > rhs);
}
constexpr inline bool operator>=(const bfloat16 &rhs) const noexcept {
return !(*this < rhs);
}

private:
constexpr static uint16_t float_to_bfloat16(float f) noexcept {
std::array<uint16_t, 2> raw_arr = std::bit_cast<std::array<uint16_t, 2>>(f);
uint16_t raw_res = 0;

switch (std::fpclassify(f)) {
case FP_SUBNORMAL:
case FP_ZERO:
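// Subnormals (and zeros) collapse to a signed zero: keep only the sign bit of the upper half.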
raw_res = raw_arr[1];
raw_res &= 0x8000;
Contributor: why not just set raw_res to 0, as we are currently doing in the default constructor?

Contributor (author): this is copy-pasted from oneDNN. I don't want to mess with some corner cases.

break;
case FP_INFINITE: raw_res = raw_arr[1]; break;
case FP_NAN:
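// Keep sign and exponent, and force a mantissa bit so the value stays a (quiet) NaN after truncation.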
raw_res = raw_arr[1];
raw_res |= 1 << 6;
break;
case FP_NORMAL:
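// Round to nearest, ties to even: add 0x7FFF plus the lowest kept mantissa bit,
// then truncate to the upper 16 bits.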
const uint32_t rounding_bias = 0x00007FFF + (raw_arr[1] & 0x1);
const uint32_t int_raw = std::bit_cast<uint32_t>(f) + rounding_bias;
raw_arr = std::bit_cast<std::array<uint16_t, 2>>(int_raw);
raw_res = raw_arr[1];
break;
}
return raw_res;
}

constexpr static float bfloat16_to_float(uint16_t v) noexcept {
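// Place the stored bits in the upper half of a float bit pattern
// (assumes a little-endian host, matching float_to_bfloat16 above).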
std::array<uint16_t, 2> raw_arr = {{0, v}};
Contributor: why {{ }}? isn't one pair of brackets enough?

return std::bit_cast<float>(raw_arr);
}
};

} // namespace ttml::math
185 changes: 185 additions & 0 deletions tt-train/tests/math/bf16_tests.cpp
@@ -0,0 +1,185 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
Contributor: 2025

//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include <cmath>
#include <core/xtensor_utils.hpp>

#include "math/bf16.hpp"
#include "xtensor/xmath.hpp"

// Test construction from float and reconversion
TEST(BFloat16Test, BasicConstruction) {
// 1) Zero
ttml::math::bfloat16 z(0.0f);
EXPECT_EQ(static_cast<float>(z), 0.0f);

// 2) Positive value
ttml::math::bfloat16 p(1.0f);
EXPECT_FLOAT_EQ(static_cast<float>(p), 1.0f);

// 3) Negative value
ttml::math::bfloat16 n(-2.5f);
EXPECT_FLOAT_EQ(static_cast<float>(n), -2.5f);
}

TEST(BFloat16Test, ValueRounding) {
// This test checks that rounding to nearest-even is happening.
// Example: 1.00007f is slightly more than 1, might round up or remain 1
float val = 1.00007f;
ttml::math::bfloat16 a(val);
float reconstructed = static_cast<float>(a);

// We can't expect an exact match, because the result may be 1.0 or slightly more
// Check closeness with an appropriate epsilon
EXPECT_NEAR(reconstructed, val, 1e-3f);
}

TEST(BFloat16Test, ConversionDouble) {
// Construct from double
double d = 3.141592653589793;
ttml::math::bfloat16 bf(d);

// Check float equivalence
float f = static_cast<float>(bf);
EXPECT_NEAR(f, static_cast<float>(d), 1e-3f);
}
/*
import torch

# Create bfloat16 tensors for a and b
a = torch.tensor(1.5, dtype=torch.bfloat16)
b = torch.tensor(2.5, dtype=torch.bfloat16)

# Perform arithmetic operations
sum_val = a + b
diff_val = a - b
prod_val = a * b
quot_val = a / b

# Print results. Note that arithmetic with bfloat16 might internally use float32 for computation.
print("a =", a.item())
print("b =", b.item())
print("sum =", sum_val.item())
print("diff =", diff_val.item())
print("prod =", prod_val.item())
print("quot =", quot_val.item())

# Output:
a = 1.5
b = 2.5
sum = 4.0
diff = -1.0
prod = 3.75
quot = 0.6015625
*/
TEST(BFloat16Test, ArithmeticOperations) {
ttml::math::bfloat16 a(1.5f);
ttml::math::bfloat16 b(2.5f);

ttml::math::bfloat16 sum = a + b;
ttml::math::bfloat16 diff = a - b;
ttml::math::bfloat16 prod = a * b;
ttml::math::bfloat16 quot = a / b;

EXPECT_NEAR(static_cast<float>(sum), 4.0f, 1e-3f);
EXPECT_NEAR(static_cast<float>(diff), -1.0f, 1e-3f);
EXPECT_NEAR(static_cast<float>(prod), 3.75f, 1e-3f);
EXPECT_NEAR(static_cast<float>(quot), 0.6f, 1e-2f);

// Compound assignments
ttml::math::bfloat16 c(2.0f);
c += ttml::math::bfloat16(3.0f);
EXPECT_NEAR(static_cast<float>(c), 5.0f, 1e-3f);

c -= ttml::math::bfloat16(1.0f);
EXPECT_NEAR(static_cast<float>(c), 4.0f, 1e-3f);

c *= ttml::math::bfloat16(2.0f);
EXPECT_NEAR(static_cast<float>(c), 8.0f, 1e-3f);

c /= ttml::math::bfloat16(4.0f);
EXPECT_NEAR(static_cast<float>(c), 2.0f, 1e-3f);
}

TEST(BFloat16Test, ComparisonOperators) {
ttml::math::bfloat16 a(1.0f), b(2.0f), c(1.0f);

EXPECT_TRUE(a < b);
EXPECT_TRUE(a <= b);
EXPECT_TRUE(b > a);
EXPECT_TRUE(b >= a);
EXPECT_TRUE(a == c);
EXPECT_TRUE(a != b);
}
/*
import torch
import math

# Create a list with the desired values
values = [65504.0, -65504.0, float('inf'), float('-inf'), float('nan')]

# Create a tensor with dtype torch.bfloat16
bf16_tensor = torch.tensor(values, dtype=torch.bfloat16)

# Print the bfloat16 tensor
print("bfloat16 tensor:", bf16_tensor)

# Optionally, convert it back to float for a clearer view
print("Converted to float:", bf16_tensor.to(torch.float32))

# Output:

bfloat16 tensor: tensor([ 65536., -65536., inf, -inf, nan], dtype=torch.bfloat16)
Converted to float: tensor([ 65536., -65536., inf, -inf, nan])
*/
TEST(BFloat16Test, CornerCases) {
// Very large float
float large_f = 65504.0f; // near max for float16, but let's see for ttml::math::bfloat16
ttml::math::bfloat16 large_bf(large_f);
float large_f_back = static_cast<float>(large_bf);
std::cout << "large_f_back: " << large_f_back << std::endl;
Contributor: remove cout? feel free to ignore if you wish to keep it

float expected_value = 65536; // 65504 + 32
EXPECT_NEAR(large_f_back, expected_value, 1e-1f);

// Negative large
float neg_large_f = -65504.0f;
ttml::math::bfloat16 neg_large_bf(neg_large_f);
float neg_large_f_back = static_cast<float>(neg_large_bf);
std::cout << "neg_large_f_back: " << neg_large_f_back << std::endl;
float expected_neg_value = -65536; // -(65504 + 32)
EXPECT_NEAR(neg_large_f_back, expected_neg_value, 1e-1f);

// Infinity
float inf = std::numeric_limits<float>::infinity();
ttml::math::bfloat16 bf_inf(inf);
float reconstructed_inf = static_cast<float>(bf_inf);
EXPECT_TRUE(std::isinf(reconstructed_inf));

// Negative Infinity
float neg_inf = -std::numeric_limits<float>::infinity();
ttml::math::bfloat16 bf_neg_inf(neg_inf);
float reconstructed_neg_inf = static_cast<float>(bf_neg_inf);
EXPECT_TRUE(std::isinf(reconstructed_neg_inf));
EXPECT_LT(reconstructed_neg_inf, 0.0f);

// NaN
float nan_val = std::numeric_limits<float>::quiet_NaN();
ttml::math::bfloat16 bf_nan(nan_val);
float reconstructed_nan = static_cast<float>(bf_nan);
EXPECT_TRUE(std::isnan(reconstructed_nan));
}
TEST(BFloat16Test, Xtensor) {
// Create an xtensor array of floats
xt::xarray<float> float_array = {1.5f, 2.5f, 3.5f};

xt::xarray<ttml::math::bfloat16> bf16_array = xt::cast<ttml::math::bfloat16>(float_array);
xt::xarray<float> sum_orig = float_array + float_array;
xt::xarray<ttml::math::bfloat16> sum_bf16 = bf16_array + bf16_array;
xt::xarray<float> bf16_sum_back = xt::cast<float>(sum_bf16);
std::cout << "sum_orig: " << sum_orig << std::endl;
std::cout << "sum_bf16: " << bf16_sum_back << std::endl;
EXPECT_TRUE(xt::allclose(bf16_sum_back, sum_orig, /*rtol=*/1e-3, /*atol=*/1e-2));
}