[tt-train] Add bf16 support #17821
base: main
Conversation
Looks good overall, a few questions and nits
@@ -5,11 +5,11 @@
#pragma once

#include <core/ttnn_all_includes.hpp>
#include <core/xtensor_utils.hpp>
#include <span>
#include <ttnn/tensor/shape/shape.hpp>
#include <ttnn/tensor/xtensor/conversion_utils.hpp>
please remove ttnn headers (not scope of this PR, but helpful)
tt-train/sources/ttml/math/bf16.cpp (outdated)
@@ -0,0 +1,7 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
2025 (the copyright year should be updated)
🐒
}

constexpr static float bfloat16_to_float(uint16_t v) noexcept {
    std::array<uint16_t, 2> raw_arr = {{0, v}};
why {{ }}? isn't one pair of brackets enough?
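For context, a minimal sketch of the conversion pattern under discussion (assumed shape, not the PR's exact code; requires C++20 and assumes a little-endian target). It also shows that a single pair of braces initializes the std::array equally well, thanks to brace elision:

#include <array>
#include <bit>
#include <cstdint>

// A bfloat16 value is the upper 16 bits of a float32, so converting back to
// float means placing it in the high half and zero-filling the low half.
constexpr float bfloat16_to_float(uint16_t v) noexcept {
    // On a little-endian target the second element becomes the high half.
    std::array<uint16_t, 2> raw_arr = {0, v};  // single braces work via brace elision
    return std::bit_cast<float>(raw_arr);
}

static_assert(bfloat16_to_float(0x3F80) == 1.0f);  // 0x3F80'0000 is 1.0f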
@@ -0,0 +1,185 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
2025 (same here: update the copyright year)
//
// SPDX-License-Identifier: Apache-2.0

#include <sys/types.h>
do we really need sys/types?
constexpr inline operator double() const noexcept {
    return static_cast<double>(bfloat16_to_float(m_raw_value));
}
can we move a few functions to cpp?
constexpr and I want all of them to be inlined.
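A brief illustration of the point (generic example, not tied to this PR's code): a constexpr function is implicitly inline and its definition must be visible in any translation unit that evaluates it at compile time, so moving the bodies into a .cpp file would break constant evaluation at the call sites.

// mymath.hpp -- the definition has to live in the header
constexpr float square(float x) noexcept { return x * x; }

// some_test.cpp -- compile-time use needs the body visible here
static_assert(square(3.0f) == 9.0f);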
constexpr inline bfloat16 operator*(const bfloat16 &rhs) const noexcept {
    float lhs_f = static_cast<float>(*this);
    float rhs_f = static_cast<float>(rhs);
discussed offline that we will use float operations, as we don't need bit-perfect comparison
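For reference, a sketch of what the float-based operator could look like once completed (assumed shape, not necessarily the PR's final code): convert both operands to float, do the arithmetic in float, and let the bfloat16 constructor round the result back down.

constexpr inline bfloat16 operator*(const bfloat16 &rhs) const noexcept {
    float lhs_f = static_cast<float>(*this);
    float rhs_f = static_cast<float>(rhs);
    return bfloat16(lhs_f * rhs_f);  // float -> bf16 rounding happens in the constructor
}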
case FP_SUBNORMAL:
case FP_ZERO:
    raw_res = raw_arr[1];
    raw_res &= 0x8000;
why not just set raw_res to 0, as we currently do in the default constructor?
this is copy-pasted from oneDNN. I don't want to mess with the corner cases.
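For readers following along, a hedged sketch of the pattern in question (names assumed; the real code follows oneDNN and also handles NaN/Inf and rounding): subnormal and zero inputs are flushed to zero while keeping the sign bit, so a negative denormal becomes -0.0 rather than +0.0, which is the corner case the copied code preserves.

#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>

// Sketch of a float -> bfloat16 conversion that flushes subnormals/zeros to a
// signed zero (simplified).
inline uint16_t float_to_bfloat16(float f) {
    std::array<uint16_t, 2> raw_arr{};
    std::memcpy(raw_arr.data(), &f, sizeof(float));  // little-endian: raw_arr[1] is the high half
    uint16_t raw_res = 0;
    switch (std::fpclassify(f)) {
        case FP_SUBNORMAL:
        case FP_ZERO:
            raw_res = raw_arr[1];
            raw_res &= 0x8000;  // keep only the sign bit
            break;
        default:
            raw_res = raw_arr[1];  // truncation toward zero for normal values
            break;
    }
    return raw_res;
}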
float large_f = 65504.0f;  // near max for float16, but let's see for ttml::math::bfloat16
ttml::math::bfloat16 large_bf(large_f);
float large_f_back = static_cast<float>(large_bf);
std::cout << "large_f_back: " << large_f_back << std::endl;
remove the cout? feel free to ignore if you wish to keep it
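If the print is dropped, the round-trip could instead be asserted on directly; a sketch assuming GoogleTest as the test framework and a tolerance of roughly one bf16 mantissa step:

float large_f = 65504.0f;  // near the float16 max; well within bfloat16's dynamic range
ttml::math::bfloat16 large_bf(large_f);
float large_f_back = static_cast<float>(large_bf);
// bfloat16 keeps ~8 mantissa bits, so allow a relative error of about 2^-8.
EXPECT_NEAR(large_f_back, large_f, large_f / 256.0f);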
…etal into DM/bf16_xtensor
Problem description
Sometimes during tests it is more convenient to use bf16 arithmetic on the host when comparing results against ttnn.
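For illustration (a hypothetical test fragment, not code from this PR): computing the expected value with the new ttml::math::bfloat16 type mirrors the precision of the device computation, so the comparison tolerance can stay tight.

// Hypothetical reference computation in bf16 precision.
ttml::math::bfloat16 a(0.1F);
ttml::math::bfloat16 b(0.2F);
float expected = static_cast<float>(a * b);  // rounded the same way bf16 hardware rounds
// ... compare `expected` against the value read back from the ttnn tensor.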
What's changed
Checklist