#0: [tt-train] DRAFT add new ttml ops #17814

Draft: wants to merge 1 commit into base: main
39 changes: 39 additions & 0 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -326,4 +326,43 @@ template tt::tt_metal::Tensor from_xtensor<uint32_t, DataType::UINT32>(
const XTensorToMeshVariant<uint32_t>& composer,
Layout layout);

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
Contributor: please throw if rank > 4 or 8?

auto logical_shape = t.get_logical_shape();
Contributor: @sminakov-tt could you take a look at this function please?

auto physical_shape = t.get_padded_shape();
auto t_rank = logical_shape.rank();
TT_FATAL(t_rank <= rank, "Cannot unsqueeze to rank {} from rank {}", rank, t_rank);

ttnn::SmallVector<uint32_t> result_logical_shape(rank);
ttnn::SmallVector<uint32_t> result_physical_shape(rank);
std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);

auto rank_diff = rank - t_rank;
std::copy(logical_shape.cbegin(), logical_shape.cend(), result_logical_shape.begin() + rank_diff);
std::copy(physical_shape.cbegin(), physical_shape.cend(), result_physical_shape.begin() + rank_diff);
return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
}

ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank) {
auto logical_shape = t.get_logical_shape();
auto physical_shape = t.get_padded_shape();
auto t_rank = logical_shape.rank();
TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);
Contributor: please don't use TT_FATAL in tt-train. Just check and throw an exception.


auto rank_diff = t_rank - rank;
bool leading_ones =
std::all_of(logical_shape.cbegin(), logical_shape.cbegin() + rank_diff, [](size_t dim) { return dim == 1; });
TT_FATAL(leading_ones, "Cannot squeeze shape {} to rank {}", logical_shape, rank);

ttnn::SmallVector<uint32_t> result_logical_shape(rank);
ttnn::SmallVector<uint32_t> result_physical_shape(rank);
std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);

std::copy(logical_shape.cbegin() + rank_diff, logical_shape.cend(), result_logical_shape.begin());
std::copy(physical_shape.cbegin() + rank_diff, physical_shape.cend(), result_physical_shape.begin());

return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
}

} // namespace ttml::core
6 changes: 6 additions & 0 deletions tt-train/sources/ttml/core/tt_tensor_utils.hpp
@@ -83,4 +83,10 @@ tt::tt_metal::Tensor from_xtensor(
const XTensorToMeshVariant<T>& composer,
Layout layout = Layout::TILE);

// Unsqueeze tensor to specified rank by adding leading dimensions of size 1
ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank);

// Squeeze tensor to specified rank by removing leading dimensions of size 1
ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank);

} // namespace ttml::core
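
As a quick illustration of the intended behavior of these two helpers (a sketch based on the implementations above; the shapes in comments are expectations, not verified output):

    // Assuming t is a ttnn::Tensor with logical shape [32, 64]:
    auto t4 = ttml::core::unsqueeze_to_rank(t, 4);  // expected logical shape [1, 1, 32, 64]
    auto t2 = ttml::core::squeeze_to_rank(t4, 2);   // expected logical shape [32, 64]
    // squeeze_to_rank fails (TT_FATAL in this draft) when a removed leading
    // dimension is not 1, e.g. squeezing logical shape [2, 1, 32, 64] to rank 2.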
107 changes: 105 additions & 2 deletions tt-train/sources/ttml/ops/binary_ops.cpp
@@ -4,6 +4,7 @@

#include "binary_ops.hpp"

#include <core/compute_kernel_config.hpp>
Contributor: please use "" (a quoted include rather than angle brackets).

#include <core/ttnn_all_includes.hpp>
#include <memory>
#include <ttnn/operations/eltwise/binary/binary.hpp>
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());

auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
Contributor: @jaykru-tt, if you decide to add broadcasting it should work for all binary ops and be implemented differently: as functions independent of the exact op (see the sketch after this hunk). Adding this code to each op would look pretty bad.

auto tensor_rank = tensor.logical_shape().rank();
if (tensor_rank == rank) {
return tensor;
} else if (tensor_rank > rank) {
return ttml::core::squeeze_to_rank(tensor, rank);
} else {
return ttml::core::unsqueeze_to_rank(tensor, rank);
}
};

auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
auto a_shape = a.get_logical_shape();
auto b_shape = b.get_logical_shape();

auto suffix_len = static_cast<int>(std::min(a_shape.size(), b_shape.size()));
for (int i = -1; i >= -suffix_len; i--) {  // signed index: avoids the unsigned wrap of -suffix_len
if (a_shape[i] != b_shape[i]) {
return false;
}
}
return true;
};

if (a->get_value().logical_shape().rank() != a_grad.logical_shape().rank()) {
if (logical_suffixes_match(a->get_value(), a_grad)) {
a_grad = clamp_to_rank(a_grad, a->get_value().logical_shape().rank());
}
}

if (b->get_value().logical_shape().rank() != b_grad.logical_shape().rank()) {
if (logical_suffixes_match(b->get_value(), b_grad)) {
b_grad = clamp_to_rank(b_grad, b->get_value().logical_shape().rank());
}
}

a->add_grad(a_grad);
b->add_grad(b_grad);
};
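
Following the reviewer's suggestion above, a minimal sketch of an op-independent helper, assuming only the squeeze/unsqueeze utilities this PR adds; the name reduce_grad_to is hypothetical:

    // Hypothetical free function (not part of this PR): bring a gradient back
    // to the rank of the value it belongs to, so every binary op can share it.
    ttnn::Tensor reduce_grad_to(const ttnn::Tensor& grad, const ttnn::Tensor& value) {
        auto value_rank = value.logical_shape().rank();
        auto grad_rank = grad.logical_shape().rank();
        if (grad_rank == value_rank) {
            return grad;
        }
        return grad_rank > value_rank ? ttml::core::squeeze_to_rank(grad, value_rank)
                                      : ttml::core::unsqueeze_to_rank(grad, value_rank);
    }
    // Usage in any op's backward pass: a->add_grad(reduce_grad_to(a_grad, a->get_value()));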
@@ -124,6 +161,14 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b) {
return out;
}

autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b) {
return b * a;
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b) {
return a * (1.F / b);
}

autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
auto out = autograd::create_tensor();

@@ -155,12 +200,70 @@ autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
return a * b;
}

autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
return a * b;
}

autograd::TensorPtr mul(float a, const autograd::TensorPtr& b) {
return b * a;
}

autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
return a / b;
}

autograd::TensorPtr div(const autograd::TensorPtr& a, float b) {
return a / b;
}

tt::tt_metal::Tensor ttnn_matmul(
const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b, bool transpose_a, bool transpose_b) {
return ttnn::matmul(
a,
b,
transpose_a,
transpose_b,
/* memory_config */ std::nullopt,
/* dtype */ std::nullopt,
/* program_config */ std::nullopt,
/* activation */ std::nullopt,
/* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
/* core_grid */ std::nullopt, // NOTE: I believe matmul will use the
// core grid for the device it ends up
// running on, but should confirm.
/* output_tile */ std::nullopt);

Contributor: I already had a comment in the other PR. Please use our core grid. If we decide to use a default parameter it should be used everywhere.

Contributor (author): Def will change this, just didn't change it yet since it won't make it into the other PR.
}

autograd::TensorPtr matmul(
const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b) {
auto out = autograd::create_tensor();
out->set_value(ttnn_matmul(a->get_value(), b->get_value(), transpose_a, transpose_b));

autograd::GradFunction grad = [a, b, out, transpose_a, transpose_b]() {
// For loss function L and matmul C = AB:
// dL/dA = dL/dC * B^T
// dL/dB = A^T * dL/dC

// When an input enters the forward matmul transposed, the operand order of
// its gradient matmul swaps as well; flipping a single flag is not enough.
auto grad_a = transpose_a
? ttnn_matmul(b->get_value(), out->get_grad(), transpose_b, /* transpose_b */ true)
: ttnn_matmul(out->get_grad(), b->get_value(), /* transpose_a */ false, !transpose_b);
auto grad_b = transpose_b
? ttnn_matmul(out->get_grad(), a->get_value(), /* transpose_a */ true, transpose_a)
: ttnn_matmul(a->get_value(), out->get_grad(), !transpose_a, /* transpose_b */ false);

a->add_grad(grad_a);
b->add_grad(grad_b);
};

auto links = autograd::get_links(a, b);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));

return out;
}

} // namespace ttml::ops
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/ops/binary_ops.hpp
@@ -12,14 +12,21 @@ autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::AutocastTensor& b);
autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b);
autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b);
autograd::TensorPtr operator-(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b);

autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::AutocastTensor& b);
autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr sub(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr mul(const autograd::TensorPtr& a, float b);
autograd::TensorPtr mul(float a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, float b);
Contributor: if we have mul(scalar, tensor) we should have a div too. (A hedged sketch follows this file's diff.)


autograd::TensorPtr matmul(
const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b);

} // namespace ttml::ops
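
Regarding the reviewer's note above about a scalar-over-tensor div: a hedged sketch of what the matching definition could look like, assuming ttnn::reciprocal is available as an eltwise unary op (everything else reuses calls already present in this diff):

    // Sketch only, not part of this PR. For y = a / b:
    // dL/db = dL/dy * (-a) / b^2.
    autograd::TensorPtr div(float a, const autograd::TensorPtr& b) {
        auto out = autograd::create_tensor();
        // Assumes ttnn::reciprocal(Tensor) exists with the usual unary signature.
        out->set_value(ttnn::multiply(ttnn::reciprocal(b->get_value()), a));
        autograd::GradFunction grad = [a, b, out]() {
            auto b_squared = ttnn::multiply(b->get_value(), b->get_value());
            auto b_grad = ttnn::multiply(ttnn::divide(out->get_grad(), b_squared), -a);
            b->add_grad(b_grad);
        };
        auto links = autograd::get_links(b);
        out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
        return out;
    }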
32 changes: 32 additions & 0 deletions tt-train/sources/ttml/ops/unary_ops.cpp
@@ -140,4 +140,36 @@ autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t new_batch_dim) {
return out;
}

autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor();
auto sqrt_tensor = ttnn::sqrt(tensor->get_value());
out->set_value(sqrt_tensor);
// Capture by value: the backward function runs after this scope has exited.
autograd::GradFunction grad = [tensor, out, sqrt_tensor]() {
// dL/dx = dL/d(sqrt(x)) * 1/(2*sqrt(x))
auto grad = ttnn::divide(out->get_grad(), ttnn::multiply(sqrt_tensor, 2.F));
tensor->add_grad(grad);
};
auto links = autograd::get_links(tensor);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
return out;
}

autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor();
Contributor: I don't like sum op without dims parameter.

out->set_value(ttml::ttnn_fixed::sum_moreh(tensor->get_value()));

autograd::GradFunction grad = [tensor, out]() {
// Distribute the gradient to each element in the original tensor
auto in_shape = tensor->get_value().get_logical_shape();

auto unsqueezed_grad = ttml::core::unsqueeze_to_rank(out->get_grad(), in_shape.rank());
auto grad_broadcast = ttnn::repeat(unsqueezed_grad, in_shape);
tensor->add_grad(grad_broadcast);
};

auto links = autograd::get_links(tensor);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
return out;
}
} // namespace ttml::ops
2 changes: 2 additions & 0 deletions tt-train/sources/ttml/ops/unary_ops.hpp
@@ -15,4 +15,6 @@ autograd::TensorPtr sum(const autograd::TensorPtr& tensor);
autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t new_batch_dim);
autograd::TensorPtr log_softmax(const autograd::TensorPtr& tensor, int dim);
autograd::TensorPtr log_softmax_moreh(const autograd::TensorPtr& tensor, int dim);
autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor);

} // namespace ttml::ops
9 changes: 9 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -73,6 +73,15 @@ tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
/* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
return res;
}

// Overload supporting generic sum over multiple dimensions
tt::tt_metal::Tensor sum_moreh(
const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims, bool keep_dim) {
Contributor: please don't use optional.

Contributor (author): I am not super familiar with best practices for std::optional in C++, so I checked the Google C++ style guide. It says one should use optional for by-value parameters that are optional. I want to give an easy way to sum over all dims without having to construct that small vector as the caller.

In this case I have, I think, 4 options:

  1. Use std::optional. nullopt passed -> sum over all dims.
  2. Use const * and treat nullptr as nothing passed -> sum over all dims. This is suitable for this specific case because we should avoid passing the vec by value.
  3. Special case empty vector as the all dims case and use that as the default value for the parameter. This seems arbitrary to me and isn't generally applicable to all types.
  4. Overload without the optional parameter.

Is nullable const * okay in this case? And in general, what should we do for passing by value where there isn't an obvious value to signal the nothing passed case?

Thanks in advance for your guidance 😁
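
To make the trade-off concrete, a sketch of options 1 and 4 side by side; sum_over_all_dims is a hypothetical name used only for illustration:

    // Option 1 (what this draft does): nullopt signals "sum over all dims".
    tt::tt_metal::Tensor sum_moreh(
        const tt::tt_metal::Tensor& t,
        std::optional<ttnn::SmallVector<int64_t>> dims = std::nullopt,
        bool keep_dim = false);

    // Option 4 (overload set, no optional): dims becomes a plain const
    // reference and the all-dims case gets its own entry point.
    tt::tt_metal::Tensor sum_moreh(
        const tt::tt_metal::Tensor& t, const ttnn::SmallVector<int64_t>& dims, bool keep_dim = false);
    tt::tt_metal::Tensor sum_over_all_dims(const tt::tt_metal::Tensor& t, bool keep_dim = false);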

tt::tt_metal::Tensor res =
ttnn::moreh_sum(t, dims, keep_dim, std::nullopt, std::nullopt, core::ComputeKernelConfig::precise());
return res;
}

tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}
4 changes: 4 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -19,5 +19,9 @@ tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor sum_moreh(
const tt::tt_metal::Tensor& t,
std::optional<ttnn::SmallVector<int64_t>> dims = std::nullopt,
bool keep_dim = false);
tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
} // namespace ttml::ttnn_fixed
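
A hedged usage sketch of the new overload (the dims values are illustrative; whether moreh_sum accepts negative dims is not verified here):

    // Sum over every dim (the nullopt default):
    auto total = ttml::ttnn_fixed::sum_moreh(t);
    // Sum over the last two dims of a rank-4 tensor, keeping them as size 1:
    auto partial = ttml::ttnn_fixed::sum_moreh(t, ttnn::SmallVector<int64_t>{2, 3}, /* keep_dim */ true);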
38 changes: 38 additions & 0 deletions tt-train/tests/ops/unary_ops_test.cpp
@@ -11,6 +11,7 @@
#include "autograd/auto_context.hpp"
#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "core/xtensor_utils.hpp"

class UnaryOpsTest : public ::testing::Test {
protected:
@@ -45,6 +46,43 @@ TEST_F(UnaryOpsTest, GlobalMean) {
}
}

TEST_F(UnaryOpsTest, Sum) {
Contributor: you must add tests for all your new ops:

  1. matmuls
  2. all new overloads of mul, div, etc.
  3. broadcasting (but I'd make a separate PR for broadcasting, as it is a pretty complex feature)

xt::xarray<float> test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}};
auto test_tensor_ptr =
ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device()));

auto result = ttml::ops::sum(test_tensor_ptr);
auto result_vector = ttml::core::to_xtensor(result->get_value());

ASSERT_TRUE(xt::allclose(result_vector, xt::sum(test_vector), 1e-5F));

result->backward();
auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());

ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad, 1e-5F));
}

TEST_F(UnaryOpsTest, Sqrt) {
xt::xarray<float> test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}};
auto test_tensor_ptr =
ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device()));

auto result = ttml::ops::sqrt(test_tensor_ptr);
auto result_vector = ttml::core::to_xtensor(result->get_value());


ASSERT_TRUE(xt::allclose(result_vector, xt::sqrt(test_vector), 1e-2F));

// FIXME(jaykru-tt): add grad test for sqrt
// result->backward();
// auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());

// ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad));
}
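
For the FIXME above, one possible shape for the missing gradient check, assuming backward() seeds an all-ones upstream gradient for non-scalar outputs (the same assumption the commented-out lines make):

    // Sketch only: d/dx sqrt(x) = 1 / (2 * sqrt(x)), so with an all-ones
    // upstream gradient each element's expected grad is 0.5 / sqrt(x).
    result->backward();
    auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad());
    xt::xarray<float> expected_grad = 0.5F / xt::sqrt(test_vector);
    ASSERT_TRUE(xt::allclose(expected_grad, test_tensor_grad, 1e-2F));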

TEST_F(UnaryOpsTest, LogSoftmax) {
auto* device = &ttml::autograd::ctx().get_device();
std::vector<float> test_data = {-0.1F, -0.2F, -0.3F, -0.4F, 0.F, -0.2F, -0.3F, -0.4F};