diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
index 9f808e151f1..2aceb247fe1 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -326,4 +326,43 @@ template tt::tt_metal::Tensor from_xtensor(
     const XTensorToMeshVariant& composer,
     Layout layout);
 
+ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank) {
+    auto logical_shape = t.get_logical_shape();
+    auto physical_shape = t.get_padded_shape();
+    auto t_rank = logical_shape.rank();
+    TT_FATAL(t_rank <= rank, "Cannot unsqueeze to rank {} from rank {}", rank, t_rank);
+
+    ttnn::SmallVector<uint32_t> result_logical_shape(rank);
+    ttnn::SmallVector<uint32_t> result_physical_shape(rank);
+    std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
+    std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);
+
+    auto rank_diff = rank - t_rank;
+    std::copy(logical_shape.cbegin(), logical_shape.cend(), result_logical_shape.begin() + rank_diff);
+    std::copy(physical_shape.cbegin(), physical_shape.cend(), result_physical_shape.begin() + rank_diff);
+    return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
+}
+
+ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank) {
+    auto logical_shape = t.get_logical_shape();
+    auto physical_shape = t.get_padded_shape();
+    auto t_rank = logical_shape.rank();
+    TT_FATAL(t_rank >= rank, "Cannot squeeze to rank {} from rank {}", rank, t_rank);
+
+    auto rank_diff = t_rank - rank;
+    bool leading_ones =
+        std::all_of(logical_shape.cbegin(), logical_shape.cbegin() + rank_diff, [](size_t dim) { return dim == 1; });
+    TT_FATAL(leading_ones, "Cannot squeeze shape {} to rank {}", logical_shape, rank);
+
+    ttnn::SmallVector<uint32_t> result_logical_shape(rank);
+    ttnn::SmallVector<uint32_t> result_physical_shape(rank);
+    std::fill(result_logical_shape.begin(), result_logical_shape.end(), 1);
+    std::fill(result_physical_shape.begin(), result_physical_shape.end(), 1);
+
+    std::copy(logical_shape.cbegin() + rank_diff, logical_shape.cend(), result_logical_shape.begin());
+    std::copy(physical_shape.cbegin() + rank_diff, physical_shape.cend(), result_physical_shape.begin());
+
+    return ttnn::reshape(t, ttnn::Shape{result_logical_shape}, ttnn::Shape{result_physical_shape});
+}
+
 } // namespace ttml::core
diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.hpp b/tt-train/sources/ttml/core/tt_tensor_utils.hpp
index 3035e7eca1e..7fdc9caaf49 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.hpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.hpp
@@ -83,4 +83,10 @@ tt::tt_metal::Tensor from_xtensor(
     const XTensorToMeshVariant& composer,
     Layout layout = Layout::TILE);
 
+// Unsqueeze tensor to specified rank by adding leading dimensions of size 1
+ttnn::Tensor unsqueeze_to_rank(const ttnn::Tensor& t, size_t rank);
+
+// Squeeze tensor to specified rank by removing leading dimensions of size 1
+ttnn::Tensor squeeze_to_rank(const ttnn::Tensor& t, size_t rank);
+
 } // namespace ttml::core
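A quick sketch of the round-trip these helpers are meant to support (tensor contents are illustrative; device access follows the usage in the tests below):

```cpp
// Sketch: unsqueeze to rank 4 and squeeze back, tracking logical shapes.
auto t = ttml::core::from_xtensor(
    xt::xarray<float>{{1.F, 2.F, 3.F}, {4.F, 5.F, 6.F}},  // shape [2, 3]
    &ttml::autograd::ctx().get_device());
auto t4 = ttml::core::unsqueeze_to_rank(t, 4);  // [2, 3] -> [1, 1, 2, 3]
auto t2 = ttml::core::squeeze_to_rank(t4, 2);   // [1, 1, 2, 3] -> [2, 3]
// squeeze_to_rank TT_FATALs when a dropped leading dim is not 1, e.g.
// squeeze_to_rank(t, 1) on the [2, 3] tensor above.
```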
diff --git a/tt-train/sources/ttml/ops/binary_ops.cpp b/tt-train/sources/ttml/ops/binary_ops.cpp
index 511f3e8a0a5..e6f33e0143d 100644
--- a/tt-train/sources/ttml/ops/binary_ops.cpp
+++ b/tt-train/sources/ttml/ops/binary_ops.cpp
@@ -4,6 +4,7 @@
 
 #include "binary_ops.hpp"
 
+#include <algorithm>
 #include
 #include
 #include
@@ -102,6 +103,42 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::Tens
         auto a_grad = ttnn::multiply(out->get_grad(), b->get_value());
         auto b_grad = ttnn::multiply(out->get_grad(), a->get_value());
 
+        auto clamp_to_rank = [](const ttnn::Tensor& tensor, size_t rank) {
+            auto tensor_rank = tensor.logical_shape().rank();
+            if (tensor_rank == rank) {
+                return tensor;
+            } else if (tensor_rank > rank) {
+                return ttml::core::squeeze_to_rank(tensor, rank);
+            } else {
+                return ttml::core::unsqueeze_to_rank(tensor, rank);
+            }
+        };
+
+        auto logical_suffixes_match = [](const ttnn::Tensor& a, const ttnn::Tensor& b) {
+            auto a_shape = a.get_logical_shape();
+            auto b_shape = b.get_logical_shape();
+
+            auto suffix_len = static_cast<int>(std::min(a_shape.size(), b_shape.size()));
+            for (int i = -1; i >= -suffix_len; i--) {
+                if (a_shape[i] != b_shape[i]) {
+                    return false;
+                }
+            }
+            return true;
+        };
+
+        if (a->get_value().logical_shape().rank() != a_grad.logical_shape().rank()) {
+            if (logical_suffixes_match(a->get_value(), a_grad)) {
+                a_grad = clamp_to_rank(a_grad, a->get_value().logical_shape().rank());
+            }
+        }
+
+        if (b->get_value().logical_shape().rank() != b_grad.logical_shape().rank()) {
+            if (logical_suffixes_match(b->get_value(), b_grad)) {
+                b_grad = clamp_to_rank(b_grad, b->get_value().logical_shape().rank());
+            }
+        }
+
         a->add_grad(a_grad);
         b->add_grad(b_grad);
     };
@@ -124,6 +161,14 @@ autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b) {
     return out;
 }
 
+autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b) {
+    return b * a;
+}
+
+autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b) {
+    return a * (1.F / b);
+}
+
 autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
     auto out = autograd::create_tensor();
 
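The rank clamping above exists for broadcasted multiplies, where the incoming gradient arrives at the higher rank of the two operands. A sketch of the case it handles (shapes illustrative; assumes ttnn::multiply broadcasts the lower-rank operand):

```cpp
// Sketch: rank-2 operand multiplied by a rank-4 operand.
auto& device = ttml::autograd::ctx().get_device();
auto a = ttml::autograd::create_tensor(ttml::core::from_xtensor(
    xt::xarray<float>(xt::ones<float>({32, 64})), &device));        // [32, 64]
auto b = ttml::autograd::create_tensor(ttml::core::from_xtensor(
    xt::xarray<float>(xt::ones<float>({1, 1, 32, 64})), &device));  // [1, 1, 32, 64]
auto out = a * b;  // forward broadcasts `a` up to rank 4
// In backward, a_grad is produced at rank 4 ([1, 1, 32, 64]); its logical
// suffix [32, 64] matches a's shape, so clamp_to_rank squeezes it back
// to rank 2 before a->add_grad(a_grad).
```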
@@ -155,12 +200,68 @@ autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr&
     return a * b;
 }
 
+autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
+    return a * b;
+}
+
+autograd::TensorPtr mul(float a, const autograd::TensorPtr& b) {
+    return b * a;
+}
+
 autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b) {
     return a / b;
 }
 
-autograd::TensorPtr mul(const autograd::TensorPtr& a, float b) {
-    return a * b;
+autograd::TensorPtr div(const autograd::TensorPtr& a, float b) {
+    return a / b;
+}
+
+tt::tt_metal::Tensor ttnn_matmul(
+    const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b, bool transpose_a, bool transpose_b) {
+    return ttnn::matmul(
+        a,
+        b,
+        transpose_a,
+        transpose_b,
+        /* memory_config */ std::nullopt,
+        /* dtype */ std::nullopt,
+        /* program_config */ std::nullopt,
+        /* activation */ std::nullopt,
+        /* compute_kernel_config */ core::ComputeKernelConfig::matmul(),
+        /* core_grid */ std::nullopt,  // NOTE: I believe matmul will use the
+                                       // core grid for the device it ends up
+                                       // running on, but should confirm.
+        /* output_tile */ std::nullopt);
+}
+
+autograd::TensorPtr matmul(
+    const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b) {
+    auto out = autograd::create_tensor();
+    out->set_value(ttnn_matmul(a->get_value(), b->get_value(), transpose_a, transpose_b));
+
+    autograd::GradFunction grad = [a, b, out, transpose_a, transpose_b]() {
+        // For loss function L and matmul C = AB:
+        // dL/dA = dL/dC * B^T
+        // dL/dB = A^T * dL/dC
+        auto grad_a = ttnn_matmul(
+            out->get_grad(),
+            b->get_value(),
+            /* transpose_a */ transpose_a,
+            /* transpose_b */ !transpose_b);
+        auto grad_b = ttnn_matmul(
+            a->get_value(),
+            out->get_grad(),
+            /* transpose_a */ !transpose_a,
+            /* transpose_b */ transpose_b);
+
+        a->add_grad(grad_a);
+        b->add_grad(grad_b);
+    };
+
+    auto links = autograd::get_links(a, b);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+
+    return out;
 }
 
 } // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ops/binary_ops.hpp b/tt-train/sources/ttml/ops/binary_ops.hpp
index 862e318f1a2..2a4def45b30 100644
--- a/tt-train/sources/ttml/ops/binary_ops.hpp
+++ b/tt-train/sources/ttml/ops/binary_ops.hpp
@@ -12,14 +12,21 @@ autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::Auto
 autograd::TensorPtr operator+(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator*(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator*(const autograd::TensorPtr& a, float b);
+autograd::TensorPtr operator*(float a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator-(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr operator/(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
+autograd::TensorPtr operator/(const autograd::TensorPtr& a, float b);
 
 autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::AutocastTensor& b);
 autograd::TensorPtr add(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr sub(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr mul(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
 autograd::TensorPtr mul(const autograd::TensorPtr& a, float b);
+autograd::TensorPtr mul(float a, const autograd::TensorPtr& b);
 autograd::TensorPtr div(const autograd::TensorPtr& a, const autograd::TensorPtr& b);
+autograd::TensorPtr div(const autograd::TensorPtr& a, float b);
+
+autograd::TensorPtr matmul(
+    const autograd::TensorPtr& a, const autograd::TensorPtr& b, bool transpose_a, bool transpose_b);
 
 } // namespace ttml::ops
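Taken together, the new declarations let autograd expressions be written in a fairly natural style. A sketch (shapes illustrative; assumes the default autograd device as in the tests):

```cpp
auto& device = ttml::autograd::ctx().get_device();
auto A = ttml::autograd::create_tensor(ttml::core::from_xtensor(
    xt::xarray<float>(xt::ones<float>({32, 64})), &device));
auto B = ttml::autograd::create_tensor(ttml::core::from_xtensor(
    xt::xarray<float>(xt::ones<float>({32, 64})), &device));

auto C = ttml::ops::matmul(A, B, /* transpose_a */ false, /* transpose_b */ true);  // [32, 32]
auto scaled = 2.F * C / 4.F;         // scalar overloads compose from either side
auto loss = ttml::ops::sum(scaled);  // reduce to a scalar so backward() has a well-defined seed
loss->backward();                    // grads reach A and B via matmul's backward node
```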
diff --git a/tt-train/sources/ttml/ops/unary_ops.cpp b/tt-train/sources/ttml/ops/unary_ops.cpp
index e2e76fb881c..22a0f5100a6 100644
--- a/tt-train/sources/ttml/ops/unary_ops.cpp
+++ b/tt-train/sources/ttml/ops/unary_ops.cpp
@@ -140,4 +140,36 @@ autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t
     return out;
 }
 
+autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor) {
+    auto out = autograd::create_tensor();
+    auto sqrt_tensor = ttnn::sqrt(tensor->get_value());
+    out->set_value(sqrt_tensor);
+    autograd::GradFunction grad = [tensor, out, sqrt_tensor]() {
+        // dL/dx = dL/d(sqrt(x)) * 1/(2*sqrt(x))
+        auto grad_value = ttnn::divide(out->get_grad(), ttnn::multiply(sqrt_tensor, 2.F));
+        tensor->add_grad(grad_value);
+    };
+    auto links = autograd::get_links(tensor);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+    return out;
+}
+
+autograd::TensorPtr sum(const autograd::TensorPtr& tensor) {
+    auto out = autograd::create_tensor();
+    out->set_value(ttml::ttnn_fixed::sum_moreh(tensor->get_value()));
+
+    autograd::GradFunction grad = [tensor, out]() {
+        // d(sum(x))/dx_i = 1: broadcast the upstream gradient to every
+        // element of the original tensor.
+        auto in_shape = tensor->get_value().get_logical_shape();
+
+        auto unsqueezed_grad = ttml::core::unsqueeze_to_rank(out->get_grad(), in_shape.rank());
+        auto grad_broadcast = ttnn::repeat(unsqueezed_grad, in_shape);
+        tensor->add_grad(grad_broadcast);
+    };
+
+    auto links = autograd::get_links(tensor);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+    return out;
+}
 } // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ops/unary_ops.hpp b/tt-train/sources/ttml/ops/unary_ops.hpp
index 669ee04233b..33e964ffe3e 100644
--- a/tt-train/sources/ttml/ops/unary_ops.hpp
+++ b/tt-train/sources/ttml/ops/unary_ops.hpp
@@ -15,4 +15,6 @@ autograd::TensorPtr sum(const autograd::TensorPtr& tensor);
 autograd::TensorPtr broadcast_batch(const autograd::TensorPtr& tensor, uint32_t new_batch_dim);
 autograd::TensorPtr log_softmax(const autograd::TensorPtr& tensor, int dim);
 autograd::TensorPtr log_softmax_moreh(const autograd::TensorPtr& tensor, int dim);
+autograd::TensorPtr sqrt(const autograd::TensorPtr& tensor);
+
 } // namespace ttml::ops
diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
index ad818f6040f..6b7a38118ad 100644
--- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
+++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -73,6 +73,15 @@ tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep
         /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
     return res;
 }
+
+// Overload supporting generic sum over multiple dimensions
+tt::tt_metal::Tensor sum_moreh(
+    const tt::tt_metal::Tensor& t, std::optional<ttnn::SmallVector<int64_t>> dims, bool keep_dim) {
+    tt::tt_metal::Tensor res =
+        ttnn::moreh_sum(t, dims, keep_dim, std::nullopt, std::nullopt, core::ComputeKernelConfig::precise());
+    return res;
+}
+
 tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
     return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
 }
diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
index c8a62d981bc..e28c6ea396c 100644
--- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
+++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -19,5 +19,9 @@ tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool kee
 tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
 
 tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
+tt::tt_metal::Tensor sum_moreh(
+    const tt::tt_metal::Tensor& t,
+    std::optional<ttnn::SmallVector<int64_t>> dims = std::nullopt,
+    bool keep_dim = false);
 tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
 } // namespace ttml::ttnn_fixed
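The defaulted-argument overload makes the global sum the no-argument case while still allowing explicit dimension subsets. A sketch (tensor and dims illustrative):

```cpp
// Assumes `t` is an existing rank-4 tt::tt_metal::Tensor, e.g. [N, C, H, W].
auto total = ttml::ttnn_fixed::sum_moreh(t);  // dims = std::nullopt: sum over every dim
auto per_sample = ttml::ttnn_fixed::sum_moreh(
    t, ttnn::SmallVector<int64_t>{1, 2, 3}, /* keep_dim */ true);  // result shape [N, 1, 1, 1]
```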
"core/xtensor_utils.hpp" class UnaryOpsTest : public ::testing::Test { protected: @@ -45,6 +46,43 @@ TEST_F(UnaryOpsTest, GlobalMean) { } } +TEST_F(UnaryOpsTest, Sum) { + xt::xarray test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}}; + auto test_tensor_ptr = + ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device())); + + auto result = ttml::ops::sum(test_tensor_ptr); + auto result_vector = ttml::core::to_xtensor(result->get_value()); + + ASSERT_TRUE(xt::allclose(result_vector, xt::sum(test_vector), 1e-5F)); + + result->backward(); + auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad()); + + ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad, 1e-5F)); +} + +TEST_F(UnaryOpsTest, Sqrt) { + xt::xarray test_vector = {{1.F, 2.F, 3.F, 4.F}, {1.F, 2.F, 3.F, 4.F}}; + auto test_tensor_ptr = + ttml::autograd::create_tensor(ttml::core::from_xtensor(test_vector, &ttml::autograd::ctx().get_device())); + + auto result = ttml::ops::sqrt(test_tensor_ptr); + auto result_vector = ttml::core::to_xtensor(result->get_value()); + + std::cout << "result_vector: " << result_vector << std::endl; + std::cout << "test_vector: " << test_vector << std::endl; + std::cout << "xt::sqrt(test_vector): " << xt::sqrt(test_vector) << std::endl; + + ASSERT_TRUE(xt::allclose(result_vector, xt::sqrt(test_vector), 1e-2F)); + + // FIXME(jaykru-tt): add grad test for sqrt + // result->backward(); + // auto test_tensor_grad = ttml::core::to_xtensor(test_tensor_ptr->get_grad()); + + // ASSERT_TRUE(xt::allclose(xt::ones_like(test_vector), test_tensor_grad)); +} + TEST_F(UnaryOpsTest, LogSoftmax) { auto* device = &ttml::autograd::ctx().get_device(); std::vector test_data = {-0.1F, -0.2F, -0.3F, -0.4F, 0.F, -0.2F, -0.3F, -0.4F};