#0: [tt-train] DRAFT add new ttml ops #17814
base: main
Conversation
@@ -326,4 +326,43 @@ template tt::tt_metal::Tensor from_xtensor<uint32_t, DataType::UINT32>(
    const XTensorToMeshVariant<uint32_t>& composer,
    Layout layout);

ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
    auto logical_shape = t.get_logical_shape();
@sminakov-tt could you take a look at this function please?
ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
please throw if rank > 4 or 8?
    auto logical_shape = t.get_logical_shape();
    auto physical_shape = t.get_padded_shape();
    auto t_rank = logical_shape.rank();
    TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);
please don't use TT_FATAL in tt-train. Just check and throw exception.
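A minimal sketch of the same guards without TT_FATAL, assuming plain standard-library exceptions are acceptable in tt-train; the exception types and the max-rank value (from the earlier review question) are illustrative choices, not existing tt-train conventions:

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Sketch only: the same checks as the diff above, reported via exceptions
// instead of TT_FATAL. Exception types and the rank limit are assumptions.
void check_rank_preconditions(std::size_t t_rank, std::size_t rank) {
    if (rank > 4) {  // or 8, per the review question above
        throw std::invalid_argument("Target rank " + std::to_string(rank) + " is not supported");
    }
    if (t_rank < rank) {
        throw std::logic_error(
            "Cannot squeeze to rank " + std::to_string(rank) + " from rank " + std::to_string(t_rank));
    }
}
```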
@@ -4,6 +4,7 @@

#include "binary_ops.hpp"

#include <core/compute_kernel_config.hpp>
please use "" instead of <> for this include.
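For reference, the quoted-include form being asked for, with the path unchanged from the diff:

```cpp
#include "core/compute_kernel_config.hpp"
```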
// Overload supporting generic sum over multiple dimensions
tt::tt_metal::Tensor sum_moreh(
    const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims, bool keep_dim) {
please don't use optional.
I am not super familiar with best practice for std::optional in C++, so I checked the Google C++ style guide. They write that one should use optional for by-value parameters that are optional. I want to give an easy way to sum over all dims without having to construct that small vector as the caller.
In this case I have, I think, 4 options:
- Use std::optional; nullopt passed -> sum over all dims.
- Use `const *` and treat nullptr as nothing passed -> sum over all dims. This is suitable for this specific case because we should avoid passing the vec by value.
- Special-case the empty vector as the all-dims case and use that as the default value for the parameter. This seems arbitrary to me and isn't generally applicable to all types.
- Overload without the optional parameter.
Is a nullable `const *` okay in this case? And in general, what should we do when passing by value where there isn't an obvious value to signal the nothing-passed case?
Thanks in advance for your guidance 😁
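A minimal sketch of the overload option (the last bullet), assuming the dims-taking version keeps roughly the signature shown in the diff (minus the optional); the all-dims construction is illustrative, not the actual implementation:

```cpp
#include <numeric>

// Explicit-dims version, roughly as in the diff.
tt::tt_metal::Tensor sum_moreh(
    const tt::tt_metal::Tensor& t, const ttnn::SmallVector<int64_t>& dims, bool keep_dim);

// Overload without the dims parameter: sum over all dims by building the full
// dims list from the tensor's rank and forwarding to the explicit version.
tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, bool keep_dim) {
    ttnn::SmallVector<int64_t> all_dims(t.get_logical_shape().rank());
    std::iota(all_dims.begin(), all_dims.end(), 0);
    return sum_moreh(t, all_dims, keep_dim);
}
```

The caller then writes `sum_moreh(t, true)` for the all-dims case without constructing the vector by hand.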
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::Tens
    auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
    auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());

auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
@jaykru-tt if you decide to add broadcasting, it should work for all binary ops and be implemented differently: it should be functions that are independent of the exact op. If we add this code to each op it will look pretty bad.
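A hedged sketch of what such an op-independent helper could look like: the binary op is passed in as a callable, so the rank-alignment logic lives in one place instead of being repeated inside each operator. `unsqueeze_to_rank` and `get_logical_shape().rank()` come from this PR's diff; everything else is illustrative and covers rank alignment only, not full shape broadcasting:

```cpp
#include <algorithm>

// Illustrative only: align both operands to a common rank, then apply whatever
// binary op was supplied. operator*, operator+, div, etc. would call this
// instead of each carrying its own clamp_to_rank-style lambda.
template <typename BinaryOp>
ttnn::Tensor apply_with_rank_alignment(const ttnn::Tensor& a, const ttnn::Tensor& b, BinaryOp&& op) {
    const auto target_rank = std::max(a.get_logical_shape().rank(), b.get_logical_shape().rank());
    return op(unsqueeze_to_rank(a, target_rank), unsqueeze_to_rank(b, target_rank));
}
```

An operator would then call something like `apply_with_rank_alignment(a->get_value(), b->get_value(), [](const auto& x, const auto& y) { return ttnn::multiply(x, y); })`.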
    /* program_config */ std::nullopt,
    /* activation */ std::nullopt,
    /* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
    /* core_grid */ std::nullopt,  // NOTE: I believe matmul will use the
I already had a comment about this in the other PR. Please use our core grid. If we decide to use the default parameter, it should be used everywhere.
Def will change this, just didn't change it yet since it won't make it into the other PR.
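For context, a heavily hedged sketch of what passing an explicit core grid could look like; `autograd::ctx().get_device()`, `compute_with_storage_grid_size()`, and `ttnn::CoreGrid` are assumptions about the surrounding tt-train/tt-metal APIs, not something this diff confirms:

```cpp
// Assumed accessors: adjust to whatever tt-train actually exposes for the device grid.
auto device_grid = autograd::ctx().get_device().compute_with_storage_grid_size();
ttnn::CoreGrid core_grid{device_grid.x, device_grid.y};
// ...then pass core_grid instead of std::nullopt for the /* core_grid */ argument above.
```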
autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
autograd::TensorPtr div(const autograd::TensorPtr& a, float b);
if we have mul(scalar, tensor) we should have a div too.
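A sketch of the scalar-first declaration being asked for, mirroring the tensor/float pair above; the exact name and placement are assumptions:

```cpp
// Scalar divided elementwise by a tensor, the counterpart of mul(scalar, tensor).
autograd::TensorPtr div(float a, const autograd::TensorPtr& b);
```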
}

autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
    auto out = autograd::create_tensor();
I don't like sum op without dims parameter.
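One way to read this request is a dims-aware signature at the autograd level, with the all-dims behaviour available through a default or an overload rather than being the only entry point; the parameter type mirrors the sum_moreh discussion above and is an assumption:

```cpp
// Sum over the given dims; an overload (or defaulted argument) can cover the
// sum-over-all-dims case so the dims-aware version stays the primary API.
autograd::TensorPtr sum(const autograd::TensorPtr& tensor, const ttnn::SmallVector<int64_t>& dims, bool keep_dim = false);
```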
@@ -45,6 +46,43 @@ TEST_F(UnaryOpsTest, GlobalMean) {
    }
}

TEST_F(UnaryOpsTest, Sum) {
you must add tests for all your new ops:
- matmuls
- all new overloads of mul, div, etc.
- broadcasting (but I'd make a separate PR for broadcasting, as it is a pretty complex feature)
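A hedged sketch of the kind of coverage being asked for, in the GoogleTest style the diff already uses; `make_test_tensor` is a placeholder for whatever construction helper the existing tests use, not a real tt-train API:

```cpp
// Illustrative structure only: each new op / overload (matmul, scalar mul/div,
// dim-wise sum, ...) would get its own test following the existing GlobalMean pattern.
TEST_F(UnaryOpsTest, SumValuesAndGradients) {
    auto t = make_test_tensor({2, 3});  // placeholder construction helper
    auto result = sum(t);
    // 1. compare result->get_value() against a host-side reference sum
    // 2. run the backward pass and check the gradient, as the GlobalMean test does
}
```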
Placeholder PR to save some work that spilled out of another PR:
Ticket
Link to Github Issue
Problem description
Provide context for the problem.
What's changed
Describe the approach used to solve the problem.
Summarize the changes made and its impact.
Checklist