
[tt-train] Add bf16 support #17821

Open · wants to merge 4 commits into main
Conversation

dmakoviichuk-tt (Contributor)

Problem description

Sometimes during tests it is more convenient to use bf16 arithmetic when comparing results against ttnn.

What's changed

  • add a bf16 class (a minimal interface sketch follows this list)
  • add arithmetic tests for bf16
  • add one xtensor test
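
For orientation, here is a minimal sketch of what such a class can look like. This is illustrative only: the name ttml::math::bfloat16 matches the test excerpt below, but the rest is an assumption (C++20 std::bit_cast, truncation instead of round-to-nearest-even) and differs from the PR's actual implementation.

```cpp
#include <bit>      // std::bit_cast (C++20)
#include <cstdint>

namespace ttml::math {

class bfloat16 {
public:
    constexpr bfloat16() noexcept = default;

    // bf16 is the top 16 bits of a float32 (1 sign, 8 exponent, 7 mantissa bits);
    // this sketch truncates the low mantissa bits rather than rounding.
    constexpr explicit bfloat16(float f) noexcept :
        m_raw_value(static_cast<uint16_t>(std::bit_cast<uint32_t>(f) >> 16)) {
    }

    constexpr explicit operator float() const noexcept {
        return std::bit_cast<float>(static_cast<uint32_t>(m_raw_value) << 16);
    }

    // Arithmetic widens to float, computes there, and narrows back to bf16.
    constexpr bfloat16 operator*(const bfloat16& rhs) const noexcept {
        return bfloat16(static_cast<float>(*this) * static_cast<float>(rhs));
    }

private:
    uint16_t m_raw_value = 0;
};

}  // namespace ttml::math
```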

Checklist

  • New/Existing tests provide coverage for changes

@rfurko-tt (Contributor) left a comment:

Looks good overall, a few questions and nits

@@ -5,11 +5,11 @@
#pragma once

#include <core/ttnn_all_includes.hpp>
#include <core/xtensor_utils.hpp>
#include <span>
#include <ttnn/tensor/shape/shape.hpp>
#include <ttnn/tensor/xtensor/conversion_utils.hpp>
Contributor:

please remove ttnn headers (not in scope of this PR, but helpful)

@@ -0,0 +1,7 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
Contributor:

2025

Contributor (Author):

🐒

}

constexpr static float bfloat16_to_float(uint16_t v) noexcept {
std::array<uint16_t, 2> raw_arr = {{0, v}};
Contributor:

why {{ }}? isn't one pair of brackets enough?
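
(Contextual note, not part of the thread: std::array is an aggregate that wraps a built-in array, so the fully braced initializer uses two pairs of braces; brace elision makes a single pair equivalent.)

```cpp
#include <array>
#include <cstdint>

std::array<uint16_t, 2> a = {{0, 1}};  // explicit: outer braces for std::array, inner for its array member
std::array<uint16_t, 2> b = {0, 1};    // identical effect via brace elision
```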

@@ -0,0 +1,185 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
Contributor:

2025

//
// SPDX-License-Identifier: Apache-2.0

#include <sys/types.h>
Contributor:

do we really need sys/types?
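
(Side note, not from the thread: the fixed-width integer types used by the bf16 code come from <cstdint>; <sys/types.h> is a POSIX header and is likely unnecessary here.)

```cpp
#include <cstdint>  // uint16_t, uint32_t: the types the bf16 code actually needs
```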


constexpr inline operator double() const noexcept {
return static_cast<double>(bfloat16_to_float(m_raw_value));
}
Contributor:

can we move a few functions to cpp?

Contributor (Author):

They are constexpr, and I want all of them to be inlined.
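
(For context: a constexpr function used in a constant expression must have its definition visible at the point of use, so moving the bodies to a .cpp would break compile-time evaluation. A small illustration, assuming the constexpr constructor and float conversion sketched earlier:)

```cpp
// Compile-time round-trip through bf16; this only works because the
// constexpr definitions live in the header and are visible here.
constexpr float half = static_cast<float>(ttml::math::bfloat16(0.5f));
static_assert(half == 0.5f);  // 0.5f is exactly representable in bf16
```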


constexpr inline bfloat16 operator*(const bfloat16 &rhs) const noexcept {
float lhs_f = static_cast<float>(*this);
float rhs_f = static_cast<float>(rhs);
Contributor:

discussed offline that we will use float operations, as we don't need bit-perfect comparison
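
(Illustration of why bit-perfect comparison is not the goal, assuming the widen/compute/narrow scheme above: bf16 keeps only 7 mantissa bits, so its machine epsilon is 2^-7 = 0.0078125 and nearby floats collapse to the same bf16 value.)

```cpp
#include <cassert>

void bf16_resolution_demo() {
    ttml::math::bfloat16 a(1.0f);
    ttml::math::bfloat16 b(1.001f);  // below bf16 resolution near 1.0, collapses to 1.0f
    assert(static_cast<float>(a) == static_cast<float>(b));
}
```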

case FP_SUBNORMAL:
case FP_ZERO:
raw_res = raw_arr[1];
raw_res &= 0x8000;
Contributor:

why not just set raw_res to 0, as we currently do in the default constructor?

Contributor (Author):

This is copy-pasted from oneDNN; I don't want to risk breaking corner cases.
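
(Contextual note, not from the thread: in the excerpt above, masking with 0x8000 keeps only the sign bit of the float's high half, i.e. zero and subnormal results are flushed to a signed zero.)

```cpp
// raw_arr[1] holds the high 16 bits of the float result; the mask
// discards everything but the sign: +x -> 0x0000 (+0.0), -x -> 0x8000 (-0.0).
uint16_t flush_to_signed_zero(uint16_t high_half) {
    return high_half & 0x8000;
}
```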

float large_f = 65504.0f; // near max for float16, but let's see for ttml::math::bfloat16
ttml::math::bfloat16 large_bf(large_f);
float large_f_back = static_cast<float>(large_bf);
std::cout << "large_f_back: " << large_f_back << std::endl;
Contributor:

remove cout? feel free to ignore if you want to keep it
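
(If the print is dropped, an assertion can capture the same check. A sketch assuming GoogleTest; the tolerance reflects bf16's spacing near 65504, which is 256, so the round-tripped value lands on 65280 or 65536 depending on the rounding mode.)

```cpp
// Allow one bf16 ulp of error around the original value.
EXPECT_NEAR(large_f_back, large_f, 256.0f);
```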

@ayerofieiev-tt changed the title from "[tt-train ]add bf16 support" to "[tt-train] Add bf16 support" on Feb 12, 2025.