
#0: [tt-train] DRAFT add new ttml ops #17814

Draft · wants to merge 1 commit into base: main
Conversation

@jaykru-tt (Contributor) commented on Feb 11, 2025

Placeholder PR to save some work that spilled out of another PR:

  • Adds matmul and sqrt with backward support (the sqrt backward rule is sketched below)
  • An attempt at adding a generic backward with broadcasting for ttml::mul; IIRC this is incomplete and ought to be revisited.
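For reference, the backward rule for sqrt is the usual chain-rule expression: if y = sqrt(x), then dL/dx = dL/dy / (2·sqrt(x)). Below is a minimal standalone sketch of that rule on plain std::vector data; it is illustrative only, not the actual ttml/ttnn tensor code.

```cpp
#include <cmath>
#include <vector>

// Elementwise sqrt backward: grad_x[i] = grad_out[i] / (2 * sqrt(x[i])).
std::vector<float> sqrt_backward(const std::vector<float>& x, const std::vector<float>& grad_out) {
    std::vector<float> grad_x(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        grad_x[i] = grad_out[i] / (2.0F * std::sqrt(x[i]));
    }
    return grad_x;
}
```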

Ticket

Link to Github Issue

Problem description

Provide context for the problem.

What's changed

Describe the approach used to solve the problem.
Summarize the changes made and their impact.

Checklist

@jaykru-tt changed the title from "#0: add new utils and ttml ops" to "#0: [tt-train] DRAFT add new ttml ops" on Feb 11, 2025
@@ -326,4 +326,43 @@ template tt::tt_metal::Tensor from_xtensor<uint32_t, DataType::UINT32>(
const XTensorToMeshVariant<uint32_t>& composer,
Layout layout);

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
auto logical_shape = t.get_logical_shape();
Contributor:
@sminakov-tt could you take a look at this function please?

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
Contributor:
please throw if rank > 4 or 8?

auto logical_shape = t.get_logical_shape();
auto physical_shape = t.get_padded_shape();
auto t_rank = logical_shape.rank();
TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);
Contributor:
please don't use TT_FATAL in tt-train. Just check and throw an exception.
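Taken together, the two suggestions above might look roughly like the sketch below, reusing the ttnn::Tensor type and shape accessors from the diff: validate the target rank up front and replace TT_FATAL with a plain check-and-throw (the condition and message are carried over verbatim from the diff; the reshape logic itself is elided).

```cpp
#include <stdexcept>
#include <string>

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
    // Suggested guard: reject target ranks beyond what the tensor layer supports (4, or 8 if enabled).
    if (rank > 4U) {
        throw std::invalid_argument("unsqueeze_to_rank: unsupported target rank " + std::to_string(rank));
    }
    auto logical_shape = t.get_logical_shape();
    auto t_rank = logical_shape.rank();
    // Same condition TT_FATAL guarded, expressed as a plain check-and-throw.
    if (!(t_rank >= rank)) {
        throw std::invalid_argument(
            "Cannot squeeze to rank " + std::to_string(rank) + " from rank " + std::to_string(t_rank));
    }
    // ... reshape logic from the diff ...
    return t;
}
```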

@@ -4,6 +4,7 @@

#include "binary_ops.hpp"

#include <core/compute_kernel_config.hpp>
Contributor:
please use "" (i.e., #include "core/compute_kernel_config.hpp" rather than the angle-bracket form)


// Overload supporting generic sum over multiple dimensions
tt::tt_metal::Tensor sum_moreh(
const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims, bool keep_dim) {
Contributor:
please don't use optional.

Contributor (Author):
I am not super familiar with best practice for std::optional in C++, so I checked the Google C++ style guide. It says to use optional for by-value parameters that are optional. I want to give callers an easy way to sum over all dims without having to construct the SmallVector themselves.

In this case I have, I think, 4 options:

  1. Use std::optional. nullopt passed -> sum over all dims.
  2. Use a const * and treat nullptr as nothing passed -> sum over all dims. This is suitable for this specific case because we should avoid passing the vector by value anyway.
  3. Special-case the empty vector as the all-dims case and use that as the default value for the parameter. This seems arbitrary to me and isn't generally applicable to all types.
  4. Overload without the optional parameter.

Is a nullable const * okay in this case? And in general, what should we do for pass-by-value parameters where there isn't an obvious value to signal the nothing-passed case? (A sketch comparing options 1 and 4 is below.)

Thanks in advance for your guidance 😁
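For comparison, here is a minimal sketch of options 1 and 4 side by side (declarations only, reusing the types from the diff; not the final API):

```cpp
// Option 1: std::optional, where std::nullopt means "sum over all dims".
tt::tt_metal::Tensor sum_moreh(
    const tt::tt_metal::Tensor& t,
    std::optional<ttnn::SmallVector<int64_t>> dims = std::nullopt,
    bool keep_dim = false);

// Option 4: overloads, no optional at the call site.
tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, bool keep_dim = false);  // all dims
tt::tt_metal::Tensor sum_moreh(
    const tt::tt_metal::Tensor& t, const ttnn::SmallVector<int64_t>& dims, bool keep_dim = false);
```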

@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::Tens
auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());

auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
Contributor:
@jaykru-tt if you decide to add broadcasting, it should work for all binary ops and be implemented differently: as functions that are independent of the exact op. If we add this code to each op, it will look pretty bad.
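One op-independent way to do that (a hedged sketch in plain C++, not the ttml/ttnn API): from the input's shape and the output-gradient's shape, compute which axes were broadcast and therefore need to be summed in the backward pass. Every binary op's backward could then call this single helper on its gradient before reducing and reshaping, instead of each op carrying its own clamp/unsqueeze code.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Given an op input's shape and the output gradient's (broadcast) shape, return the gradient
// axes that must be summed so the reduced gradient matches the input shape again.
std::vector<int64_t> broadcast_reduction_axes(
    const std::vector<int64_t>& input_shape, const std::vector<int64_t>& grad_shape) {
    if (grad_shape.size() < input_shape.size()) {
        throw std::invalid_argument("gradient rank must be >= input rank");
    }
    std::vector<int64_t> axes;
    const size_t rank_diff = grad_shape.size() - input_shape.size();
    for (size_t axis = 0; axis < grad_shape.size(); ++axis) {
        if (axis < rank_diff) {
            // Leading axes the input never had: always reduced.
            axes.push_back(static_cast<int64_t>(axis));
        } else if (input_shape[axis - rank_diff] == 1 && grad_shape[axis] != 1) {
            // Axes where the input was broadcast from size 1: reduced with keep_dim.
            axes.push_back(static_cast<int64_t>(axis));
        }
    }
    return axes;
}
```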

/* program_config */ std::nullopt,
/* activation */ std::nullopt,
/* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
/* core_grid */ std::nullopt, // NOTE: I believe matmul will use the
Contributor:
I already left a comment about this in the other PR. Please use our core grid. If we decide to rely on the default parameter, it should be used everywhere.

Contributor (Author):
Definitely will change this; I just haven't yet since it won't make it into the other PR.

autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, float b);
Contributor:
If we have mul(scalar, tensor), we should have a div counterpart too.
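For symmetry, the missing overload would presumably be declared alongside the existing ones from the diff (a sketch, not the final signature):

```cpp
// Existing declarations from the diff:
autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, float b);

// Suggested counterpart to mul(scalar, tensor): scalar numerator divided elementwise by the tensor.
autograd::TensorPtr div(float a, const autograd::TensorPtr& b);
```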

}

autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor();
Contributor:
I don't like a sum op without a dims parameter.
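A possible shape for that (a hedged sketch reusing the dims representation from sum_moreh above; not the final ttml API):

```cpp
// Sum over an explicit set of dimensions; a separate overload (or a default) can keep the
// existing "sum everything" behaviour used for loss reductions.
autograd::TensorPtr sum(
    const autograd::TensorPtr& tensor, const ttnn::SmallVector<int64_t>& dims, bool keep_dim = false);
```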

@@ -45,6 +46,43 @@ TEST_F(UnaryOpsTest, GlobalMean) {
}
}

TEST_F(UnaryOpsTest, Sum) {
Contributor:
You must add tests for all your new ops (a hedged example test follows this list):

  1. matmul
  2. all new overloads of mul, div, etc.
  3. broadcasting (but I'd make a separate PR for broadcasting, as it is a pretty complex feature)
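For example, a forward-pass matmul test in the style of the existing fixtures might look like the sketch below. The fixture name, the make_tensor/to_vector helpers, and the ops::matmul entry point are assumptions standing in for tt-train's real test utilities, not the actual API.

```cpp
TEST_F(BinaryOpsTest, MatmulForward) {
    // 2x2 @ 2x2 with a hand-computed expectation: [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]].
    auto a = make_tensor({1.F, 2.F, 3.F, 4.F}, /*shape=*/{1, 1, 2, 2});  // hypothetical helper
    auto b = make_tensor({5.F, 6.F, 7.F, 8.F}, /*shape=*/{1, 1, 2, 2});

    auto out = ops::matmul(a, b);  // assumed entry point for the new op

    auto result = to_vector(out);  // hypothetical helper reading the tensor back to host
    std::vector<float> expected{19.F, 22.F, 43.F, 50.F};
    ASSERT_EQ(result.size(), expected.size());
    for (size_t i = 0; i < expected.size(); ++i) {
        EXPECT_NEAR(result[i], expected[i], 1e-2F);
    }
}
```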
