From 503ca1cf1f97fde9399a55c678a9ba5b3815d94e Mon Sep 17 00:00:00 2001 From: Jack Cai Date: Fri, 7 Feb 2025 20:37:13 +0000 Subject: [PATCH 01/25] added ccl async llama perf and correctness measurement tests to CI --- .../tg/ccl/test_ccl_async_TG_llama_nightly.py | 1 + .../ccl/test_all_gather_TG_post_commit.py | 9 ++++++ .../operations/ccl/test_ccl_async_TG_llama.py | 30 ++++++++++++++----- 3 files changed, 33 insertions(+), 7 deletions(-) create mode 120000 tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py diff --git a/tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py b/tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py new file mode 120000 index 00000000000..18ed2ca2998 --- /dev/null +++ b/tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py @@ -0,0 +1 @@ +../../../ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py \ No newline at end of file diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index a476163c8d5..7f37600028a 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -14,6 +14,7 @@ teardown_fabric_interface, create_global_semaphore_with_same_address, ) +from models.perf.benchmarking_utils import BenchmarkProfiler def report_mismatches(golden, actual, max_printable=None): @@ -64,6 +65,7 @@ def run_with_trace( n_buffer=None, num_iter=20, use_all_gather_async=False, + profiler=BenchmarkProfiler(), ): # Compile Run logger.info("Compiling model") @@ -131,10 +133,15 @@ def run_with_trace( # Run the op logger.info("Starting Trace perf test...") + profiler.start("all-gather-async-trace") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) for d in mesh_device.get_devices(): ttnn.synchronize_device(d) + profiler.end("all-gather-async-trace") + logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s") + logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s") + logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us") return tt_out_tensor @@ -160,6 +167,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( tile=(32, 32), trace_mode=False, debug=False, + profiler=BenchmarkProfiler(), # New all-gather-async and persistent fabric params use_all_gather_async=False, enable_persistent_fabric=False, @@ -270,6 +278,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( all_gather_topology=ttnn.Topology.Linear, num_iter=num_iters, use_all_gather_async=use_all_gather_async, + profiler=profiler, ) else: diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index c1673280601..e413e3c1eef 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -23,6 +23,7 @@ from tests.ttnn.unit_tests.operations.ccl.test_all_reduce_async import ( run_all_reduce_with_mesh_tensor_along_row, ) +from models.perf.benchmarking_utils import BenchmarkProfiler PREFETCHER_NOC1_RING = [ @@ -79,22 +80,25 @@ def get_core_range_set(output_core_grid): "num_devices, num_links", [ (4, 3), - (4, 2), - (4, 1), ], ) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, ttnn.bfloat8_b, ], ) +@pytest.mark.parametrize( + "num_iters", + [ + 5000, + 
], +) @pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @pytest.mark.parametrize( - "tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout", + "tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout, perf_target_us", ( - ( # AllGather after SDPA (~160 us) + ( # AllGather after SDPA ttnn.TensorMemoryLayout.HEIGHT_SHARDED, (1, 32, 32, 128), 1, @@ -108,8 +112,9 @@ def get_core_range_set(output_core_grid): } ), ttnn.TILE_LAYOUT, + 40, ), - ( # AllGather after Binary Mult+Silu (~160 us) + ( # AllGather after Binary Mult+Silu ttnn.TensorMemoryLayout.WIDTH_SHARDED, (1, 1, 32, 3840), 3, @@ -118,6 +123,7 @@ def get_core_range_set(output_core_grid): (32, 160), get_core_range_set(PREFETCHER_NOC1_RING), ttnn.TILE_LAYOUT, + 32, ), ), ) @@ -143,7 +149,8 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async, replication_factor, - num_iters=100, + num_iters, + perf_target_us, ): if len(mesh_device.get_devices()) != 32: pytest.skip("Not TG!") @@ -162,6 +169,8 @@ def test_line_all_gather_sharded_on_TG_rows_llama( else: output_shard_spec = None + profiler = BenchmarkProfiler() + run_line_all_gather_on_TG_with_mesh_tensor_along_rows( mesh_device, num_devices, @@ -180,6 +189,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( output_shard_spec=output_shard_spec, num_all_gather_instances=replication_factor, cluster_axis=1, + profiler=profiler, trace_mode=True, use_all_gather_async=True, enable_persistent_fabric=True, @@ -187,6 +197,12 @@ def test_line_all_gather_sharded_on_TG_rows_llama( teardown_persistent_fabric=True, ) + latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6 + if perf_target_us is not None: + assert ( + latency_us < perf_target_us + ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us" + @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( From ccd3f982d60d2a0f4144597ee50cca04f121c043 Mon Sep 17 00:00:00 2001 From: Jack Cai Date: Mon, 10 Feb 2025 16:58:46 +0000 Subject: [PATCH 02/25] updated ccl perf target for llama all gather async --- .../ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index e413e3c1eef..fe967467e14 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -112,7 +112,7 @@ def get_core_range_set(output_core_grid): } ), ttnn.TILE_LAYOUT, - 40, + 32, ), ( # AllGather after Binary Mult+Silu ttnn.TensorMemoryLayout.WIDTH_SHARDED, @@ -123,7 +123,7 @@ def get_core_range_set(output_core_grid): (32, 160), get_core_range_set(PREFETCHER_NOC1_RING), ttnn.TILE_LAYOUT, - 32, + 25, ), ), ) From 0c9ccd79e6fcb7fdb170da3bed433f75ae0d3a48 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Mon, 3 Feb 2025 11:17:51 -0600 Subject: [PATCH 03/25] WIP adding infra for the new all reduce. 
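
This patch adds the ttnn::AllReduceAsync device operation (op struct, program
factory, and CMake entries), a cluster-axis overload of
ttnn.experimental.all_reduce_async with python bindings, and a first
correctness test. Below is a minimal usage sketch mirroring the call in the new
tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py; the mesh device,
persistent fabric, width-sharded memory configs, and global semaphore are
assumed to be set up as in that test, and names such as tt_input_tensor and
ccl_semaphore_handles come from it rather than being a fixed API surface:

    # Sketch only: fabric/sub-device setup and the sharded input/output
    # memory configs are created beforehand, exactly as in the test.
    out = ttnn.experimental.all_reduce_async(
        tt_input_tensor,
        cluster_axis=1,
        mesh_device=mesh_device,
        multi_device_global_semaphore=ccl_semaphore_handles,
        memory_config=output_mem_config,
        topology=ttnn.Topology.Linear,
        num_links=num_links,
        subdevice_id=worker_sub_device_id,
    )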
--- .../operations/ccl/test_new_all_reduce.py | 237 +++++++++++++ .../experimental/ccl/CMakeLists.txt | 2 + .../ccl/all_reduce_async/all_reduce_async.cpp | 23 ++ .../ccl/all_reduce_async/all_reduce_async.hpp | 10 + .../all_reduce_async_pybind.cpp | 30 ++ .../device/all_reduce_async_op.cpp | 243 +++++++++++++ .../device/all_reduce_async_op.hpp | 132 +++++++ ..._reduce_async_program_minimal_variants.cpp | 333 ++++++++++++++++++ 8 files changed, 1010 insertions(+) create mode 100644 tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py new file mode 100644 index 00000000000..daf71be854c --- /dev/null +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import math +from loguru import logger +import ttnn +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc +from models.utility_functions import skip_for_grayskull +from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( + create_and_load_sub_device_manager_with_fabric_interface, + teardown_fabric_interface, + create_global_semaphore_with_same_address, +) + +from tests.tt_eager.python_api_testing.unit_testing.misc.test_matmul_1d_gather_in0 import ( + num_cores_to_rectangle_grid, + round_up, +) + + +def run_all_reduce_impl( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + num_iters=1, + enable_async=False, + trace_mode=False, +): + cluster_shape = (8, 4) + + create_persistent_fabric = True + teardown_persistent_fabric = True + enable_persistent_fabric = True + if num_iters < 1: + pytest.fail("num_iters must be >= 1") + # Use Async mode based on test input config + mesh_device.enable_async(enable_async) + + if enable_async: + logger.info(f"Using Async Mode for All Gather Op Dispatch") + + ################################## + ##### Set up fabric stuff + ################################## + compute_grid_size = mesh_device.compute_with_storage_grid_size() + ccl_sub_device_crs = ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))} + ) + worker_sub_device = ttnn.SubDevice( + [ + ccl_sub_device_crs, + ] + ) + worker_sub_device_id = ttnn.SubDeviceId(0) + sub_device_stall_group = [worker_sub_device_id] + if create_persistent_fabric: + mesh_sub_device_manager_id = create_and_load_sub_device_manager_with_fabric_interface( + mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric + ) + mesh_device.set_sub_device_stall_group(sub_device_stall_group) + + # create global semaphore handles + ccl_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) + + logger.info(f"Output shape: {output_shape}") + + input_tensor_mesh_list = [] + output_tensor_goldens_list = [] + tt_outs = [] + + try: + ################################## + ##### Set up input tensors/configs + ################################## + + ##### FF2 Case ##### + 
num_cores = 24 # matmul ring + M, N = output_shape[2:] + N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE) + input_shape = [*cluster_shape, M, N] + + input_grid = num_cores_to_rectangle_grid(num_cores, mesh_device) + CORE_RANGE = [(x, y) for y in range(input_grid[1]) for x in range(input_grid[0])] + core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in CORE_RANGE + ] + ) + input_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + core_range_set, + [M, N_per_shard], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + output_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + core_range_set, + [M, N_per_shard * 4], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + + input_tensor = torch.randn(input_shape) + tt_input_tensor = ttnn.from_torch( + input_tensor, + device=mesh_device, + layout=ttnn.TILE_LAYOUT, + dtype=input_dtype, + memory_config=input_mem_config, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(0, 1), mesh_shape=cluster_shape), + ) + + ################################## + ##### Run the op + ################################## + + def run_op(): + outs = [] + for i in range(num_iters): + out = ttnn.experimental.all_reduce_async( + tt_input_tensor, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + multi_device_global_semaphore=ccl_semaphore_handles, + memory_config=output_mem_config, + topology=ttnn.Topology.Linear, + num_links=num_links, + subdevice_id=worker_sub_device_id, + ) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + outs.append(out) + + return outs + + # ##### Compile Model ##### + # logger.info("Compiling model") + # tt_outs = run_op() + + # ##### Capture Trace ##### + # logger.info("Capturing trace") + + # trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + # tt_outs = run_op() + # ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + + # ##### Run Trace ##### + # logger.info("Running trace") + # ttnn.execute_trace(mesh_device, trace_id, blocking=False) + + tt_outs = run_op() + + ################################## + ##### Validation + ################################## + for tensor_index in range(len(tt_outs)): + tt_out_tensor = tt_outs[tensor_index] + output_tensor = output_tensor_goldens_list[tensor_index] + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + logger.info(f"Checking for device {t.device().id()}") + + if input_dtype == ttnn.bfloat16: + eq, output = comp_equal(tt_output_tensor, output_tensor) + else: + eq, output = comp_pcc(tt_output_tensor, output_tensor) + if not eq: + logger.error(f"output mismatch for tensor {i}") + assert eq, f"{i} FAILED: {output}" + finally: + if enable_persistent_fabric and teardown_persistent_fabric: + mesh_device.reset_sub_device_stall_group() + teardown_fabric_interface(mesh_device) + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "output_shape, cluster_axis, num_links", + [ + ([1, 1, 32, 3840], 1, 1), + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + # ttnn.bfloat16, + ttnn.bfloat8_b, + ], +) +@pytest.mark.parametrize("num_iters", [1]) +@pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize( + "device_params", + [{"trace_region_size": 23887872}], + indirect=True, +) +@pytest.mark.parametrize( + "mesh_device", + [ + (8, 4), + ], + 
indirect=True, +) +def test_all_reduce( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + num_iters, + enable_async, + use_program_cache, + function_level_defaults, +): + run_all_reduce_impl( + mesh_device, + output_shape, + cluster_axis, + input_dtype, + num_links, + num_iters=num_iters, + enable_async=enable_async, + ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt index e80883ac5f7..e8af00e435a 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt @@ -19,6 +19,8 @@ set(CCL_EXPERIMENTAL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_program_minimal_variants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/all_reduce_async.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/all_reduce_async_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/device/all_reduce_async_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp CACHE INTERNAL "CCL Experimental sources to reuse in ttnn build" ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp index 1a94281724e..41cad729290 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp @@ -6,6 +6,7 @@ #include "cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp" #include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "device/all_reduce_async_op.hpp" #include "cpp/ttnn/global_semaphore.hpp" namespace ttnn::operations::experimental::ccl { @@ -106,4 +107,26 @@ ttnn::Tensor ExecuteAllReduceAsync::invoke( true); } +ttnn::Tensor ExecuteAllReduceAsync::invoke( + const ttnn::Tensor& input_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_preferred_links, + std::optional worker_subdevice_id_opt) { + MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config()); + return ttnn::operations::experimental::ccl::all_reduce_async( + input_tensor, + cluster_axis, + mesh_device, + topology, + multi_device_global_semaphore, + out_memory_config, + num_preferred_links, + worker_subdevice_id_opt, + true); +} + } // namespace ttnn::operations::experimental::ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp index b0b80451b8e..b6ae2ef72cc 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp @@ -41,6 +41,16 @@ struct ExecuteAllReduceAsync { ttnn::ccl::Topology topology, const std::optional num_preferred_links, std::optional worker_subdevice_id_opt); + + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + 
ttnn::ccl::Topology topology, + const std::optional num_preferred_links, + std::optional worker_subdevice_id_opt); }; } // namespace ccl diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp index ec91d88a2be..bf3e4ec1e6f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp @@ -94,6 +94,36 @@ void bind_all_reduce(pybind11::module& module, const ccl_operation_t& operation, py::arg("memory_config") = std::nullopt, py::arg("topology") = ttnn::ccl::Topology::Linear, py::arg("num_links") = std::nullopt, + py::arg("subdevice_id") = std::nullopt}, + + ttnn::pybind_overload_t{ + [](const ccl_operation_t& self, + const ttnn::Tensor& input_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const ttnn::MemoryConfig& memory_config, + ttnn::ccl::Topology topology, + const std::optional num_links, + std::optional worker_subdevice_id_opt) -> ttnn::Tensor { + return self( + input_tensor, + cluster_axis, + mesh_device, + multi_device_global_semaphore, + memory_config, + topology, + num_links, + worker_subdevice_id_opt); + }, + py::arg("input_tensor"), + py::arg("cluster_axis"), + py::arg("mesh_device"), + py::arg("multi_device_global_semaphore"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("topology") = ttnn::ccl::Topology::Linear, + py::arg("num_links") = std::nullopt, py::arg("subdevice_id") = std::nullopt}); } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp new file mode 100644 index 00000000000..06e594066d9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "all_reduce_async_op.hpp" +#include "ttnn/operations/math.hpp" +#include "cpp/ttnn/global_semaphore.hpp" + +#include + +#include "ttnn/tensor/tensor_utils.hpp" + +namespace ttnn { +namespace ccl { +namespace all_reduce_detail { + +AllReduceAsync create_all_reduce_async_struct( + const Tensor& input_tensor, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ttnn::ccl::Topology topology, + const std::vector& semaphores, + std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + uint32_t num_devices = devices.size(); + + std::optional forward_device = std::nullopt; + std::optional backward_device = std::nullopt; + std::optional semaphore = std::nullopt; + uint32_t device_index = 0; // Initialize device index + for (uint32_t i = 0; i < num_devices; ++i) { + if (devices.at(i) == input_tensor.device()) { + device_index = i; + semaphore = semaphores.at(i); // Get raw pointer + if (i != 0) { + backward_device = devices.at(i - 1); + } + if (i != num_devices - 1) { + forward_device = devices.at(i + 1); + } + } + } + + return ttnn::AllReduceAsync{ + forward_device, + backward_device, + num_links, + num_devices, + device_index, + 
memory_config.value_or(input_tensor.memory_config()), + topology, + semaphore.value(), + sub_device_id, + enable_persistent_fabric_mode}; +} + +uint32_t find_scatter_dim(const ttnn::SimpleShape& input_tensor_padded_shape, size_t num_workers) { + // iterate until we find a dimension that is divisible by num_workers + TT_FATAL(input_tensor_padded_shape.size() == 4, "Expected input tensor to have 4 dimensions"); + ttnn::SimpleShape input_tensor_shape_in_tiles{ + input_tensor_padded_shape[0], + input_tensor_padded_shape[1], + input_tensor_padded_shape[2] / tt::constants::TILE_HEIGHT, + input_tensor_padded_shape[3] / tt::constants::TILE_WIDTH}; + for (uint32_t dim = 0; dim < 4; ++dim) { + if (input_tensor_shape_in_tiles[dim] % num_workers == 0) { + tt::log_debug( + "Found scatter dimension {} for input tensor with padded shape {}", dim, input_tensor_padded_shape); + return dim; + } + } + TT_THROW( + "No scatter dimension found for input tensor with padded shape {} and num_workers {}", + input_tensor_padded_shape, + num_workers); +} + +} // namespace all_reduce_detail +} // namespace ccl + +void AllReduceAsync::validate(const std::vector& input_tensors) const { + TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but has {}", input_tensors.size()); + const auto& input_tensor = input_tensors[0]; + const auto& layout = input_tensors[0].get_layout(); + const auto& dtype = input_tensors[0].get_dtype(); + const auto& page_size = input_tensors[0].buffer()->page_size(); + TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages"); + + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_reduce need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr, "Operands to all_reduce need to be allocated in buffers on device!"); + TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links); + TT_FATAL( + this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, + "Worker cores used by links are parallelizaed over rows"); + + TT_FATAL( + input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED || + input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Unsupported memory layout {}.", + input_tensor.memory_config().memory_layout); +} + +static void validate_output_tensor_allocation(const std::vector& output_tensors) { + for (const auto& output_tensor : output_tensors) { + const auto& buffers = output_tensor.buffers(); + const auto first_address = buffers.front()->address(); + TT_FATAL( + std::all_of( + buffers.begin(), + buffers.end(), + [&first_address](const auto& buffer) { + return buffer != nullptr && buffer->address() == first_address; + }), + "Output buffers for all_reduce async must be lock-step allocated but some of the tensors were allocated at " + "different addresses across devices."); + } +} + +std::vector AllReduceAsync::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors[0]; + auto shape = input_tensor.get_padded_shape(); // TODO: Replace with get_logical_shape() + shape[3] *= this->ring_size; + return {TensorSpec( + shape, + TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config))}; +} + 
+operation::ProgramWithCallbacks AllReduceAsync::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + tt::log_debug(tt::LogOp, "DEBUG: create_program is called"); + + auto input_tensor_shape = input_tensors[0].get_padded_shape(); + auto input_tensor_buffer_layout = input_tensors[0].buffer()->buffer_layout(); + auto input_tensor_page_layout = input_tensors[0].layout(); + + auto input_tensor_memory_config = input_tensors[0].memory_config(); + auto output_tensor_memory_config = output_tensors[0].memory_config(); + uint32_t input_shard_num_cores = input_tensor_memory_config.shard_spec->grid.num_cores(); + uint32_t output_shard_num_cores = output_tensor_memory_config.shard_spec->grid.num_cores(); + + tt::log_debug(tt::LogOp, "input_tensor_shape: {}", input_tensor_shape); + tt::log_debug(tt::LogOp, "input_tensor_memory_config: {}", input_tensor_memory_config); + tt::log_debug(tt::LogOp, "output_tensor_memory_config: {}", output_tensor_memory_config); + tt::log_debug(tt::LogOp, "input_shard_num_cores: {}", input_shard_num_cores); + tt::log_debug(tt::LogOp, "output_shard_num_cores: {}", output_shard_num_cores); + tt::log_debug( + tt::LogOp, "input_tensor_memory_config.shard_spec->shape: {}", input_tensor_memory_config.shard_spec->shape); + tt::log_debug( + tt::LogOp, "output_tensor_memory_config.shard_spec->shape: {}", output_tensor_memory_config.shard_spec->shape); + + tt::log_info(tt::LogOp, "Running TG Llama specific all_reduce_async_minimal_multi_core_with_workers"); + return all_reduce_async_minimal_multi_core_with_workers( + input_tensors[0], + this->forward_device, + this->backward_device, + output_tensors[0], + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore, + this->sub_device_id, + this->enable_persistent_fabric_mode); +} + +const operation::Hash AllReduceAsync::compute_program_hash(const std::vector& input_tensors) const { + return operation::hash_operation( + this->num_links, this->ring_size, this->ring_index, this->output_mem_config, this->topology); +} + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_reduce_async( + const Tensor& input_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config, + const std::optional num_preferred_links, + std::optional subdevice_id, + bool enable_persistent_fabric_mode) { + TT_FATAL( + topology == ttnn::ccl::Topology::Linear, + "This all_reduce API with cluster_axis is currently supported only for the Linear topology"); + const auto mesh_view = mesh_device.get_view(); + auto devices = input_tensor.get_workers(); + std::size_t num_devices = (cluster_axis == 0) ? 
mesh_view.num_rows() : mesh_view.num_cols(); + + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); + auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); + std::vector semaphores = multi_device_global_semaphore.global_semaphores; + + operation::launch_op( + [num_preferred_links, + memory_config, + mesh_view, + cluster_axis, + num_devices, + topology, + semaphores, + subdevice_id, + enable_persistent_fabric_mode]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + const auto& input_device_tensor = input_tensors.at(0); + + const auto coordinate = mesh_view.find_device(input_device_tensor.device()->id()); + std::vector devices = (cluster_axis == 0) ? mesh_view.get_devices_on_column(coordinate.col) + : mesh_view.get_devices_on_row(coordinate.row); + + const auto& input_tensor = input_tensors.at(0); + + return operation::run( + ttnn::ccl::all_reduce_detail::create_all_reduce_async_struct( + input_device_tensor, + num_preferred_links.has_value() ? num_preferred_links.value() : 1, + memory_config, + devices, + topology, + semaphores, + subdevice_id, + enable_persistent_fabric_mode), + {input_tensor}); + }, + {input_tensor}, + output_tensors); + return output_tensors.at(0); +} + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp new file mode 100644 index 00000000000..05c89fa9b73 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/ccl/ccl_op_fusion.hpp" +#include +#include "cpp/ttnn/global_semaphore.hpp" + +#include "ttnn/run_operation.hpp" + +#include +#include + +namespace ttnn { + +using ccl::EriscDatamoverBuilder; + +struct AllReduceAsync { + std::optional forward_device; + std::optional backward_device; + const uint32_t num_links; + const uint32_t ring_size; + const uint32_t ring_index; + const MemoryConfig output_mem_config; + const ccl::Topology topology; + const GlobalSemaphore semaphore; + std::optional sub_device_id; + bool enable_persistent_fabric_mode; + + AllReduceAsync( + std::optional forward_device, + std::optional backward_device, + uint32_t num_links, + uint32_t ring_size, + uint32_t ring_index, + MemoryConfig output_mem_config, + ccl::Topology topology, + GlobalSemaphore semaphore, + std::optional& sub_device_id, + bool enable_persistent_fabric_mode) : + forward_device(forward_device), + backward_device(backward_device), + num_links(num_links), + ring_size(ring_size), + ring_index(ring_index), + output_mem_config(output_mem_config), + topology(topology), + semaphore(semaphore), + sub_device_id(sub_device_id), + enable_persistent_fabric_mode(enable_persistent_fabric_mode) {} + + // Add attributes method for reflection + auto attributes() 
const { + using tt::stl::reflection::Attribute; + std::vector> attrs; + + attrs.emplace_back("num_links", num_links); + attrs.emplace_back("ring_size", ring_size); + attrs.emplace_back("ring_index", ring_index); + attrs.emplace_back("output_mem_config", output_mem_config); + attrs.emplace_back("topology", topology); + attrs.emplace_back("semaphore", semaphore); + + return attrs; + } + + void validate(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; + const operation::Hash compute_program_hash(const std::vector& input_tensors) const; +}; + +namespace ccl { +namespace all_reduce_async_detail { +AllReduceAsync create_all_reduce_async_struct( + const Tensor& input_tensor, + const uint32_t num_links, + const std::optional& memory_config, + const std::vector& devices, + const ccl::Topology topology, + const std::vector& semaphores, + std::optional sub_device_id, + bool enable_persistent_fabric_mode); + +uint32_t find_scatter_dim(const ttnn::SimpleShape& input_tensor_padded_shape, size_t num_workers); +} // namespace all_reduce_async_detail +} // namespace ccl + +operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode); + +namespace operations { +namespace experimental { +namespace ccl { + +Tensor all_reduce_async( + const Tensor& input_tensor, + const uint32_t cluster_axis, + const MeshDevice& mesh_device, + const ttnn::ccl::Topology topology, + const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const std::optional& memory_config = std::nullopt, + const std::optional num_preferred_links = std::nullopt, + std::optional sub_device_id = std::nullopt, + bool enable_persistent_fabric_mode = false); + +} // namespace ccl +} // namespace experimental +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp new file mode 100644 index 00000000000..9942c3daafd --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -0,0 +1,333 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +/// +#include + +#include +#include +#include "ttnn/tensor/tensor_impl.hpp" +#include "all_reduce_async_op.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/math.hpp" +#include +#include +#include +#include +#include "cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include "cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp" + +#include "cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include 
"cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp" +#include +#include +#include +#include +using namespace tt::constants; + +namespace ttnn { + +using namespace ccl; + +operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + std::cout << "RUNNING NEW ALL REDUCE ASYNC" << std::endl; + std::cout << "num_links: " << num_links << std::endl; + std::cout << "ring_size: " << ring_size << std::endl; + std::cout << "ring_index: " << ring_index << std::endl; + + tt::tt_metal::Program program{}; + const bool enable_async_output_tensor = false; + TT_FATAL( + enable_persistent_fabric_mode, + "only persistent fabric mode is supported for all_gather_async_llama_post_binary_matmul"); + + IDevice* device = input_tensor.device(); + bool is_first_chip = ring_index == 0; + bool is_last_chip = ring_index == ring_size - 1; + log_trace( + tt::LogOp, + "DEBUG: device: {}, is_first_chip: {}, is_last_chip: {}", + input_tensor.device()->id(), + is_first_chip, + is_last_chip); + + std::optional local_fabric_handle = + ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links); + + // Get OP Config, topology config + std::vector input_tensors = {input_tensor}; + std::vector output_tensors = {output_tensor}; + const auto& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + LineTopology line_topology(ring_size, ring_index); + const size_t num_targets_forward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + const size_t num_targets_backward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + + // Get worker cores, assuming 1 worker per link + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = + choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device); + + // Tensor Info + const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); + const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; + const auto input_tensor_shard_shape = input_tensor.memory_config().shard_spec->shape; + const auto input_tensor_shard_num_pages = input_tensor_shard_shape[0] * input_tensor_shard_shape[1] / TILE_HW; + const auto output_tensor_cores = output_tensor.memory_config().shard_spec->grid; + const auto output_tensor_shard_shape = output_tensor.memory_config().shard_spec->shape; + const auto output_tensor_shard_num_pages = output_tensor_shard_shape[0] * output_tensor_shard_shape[1] / TILE_HW; + + tt::log_debug(tt::LogOp, "input_tensor_num_pages: {}", input_tensor_num_pages); + tt::log_debug(tt::LogOp, "input_tensor_cores: {}", input_tensor_cores); + tt::log_debug(tt::LogOp, "input_tensor_shard_shape: {}", input_tensor_shard_shape); + tt::log_debug(tt::LogOp, "input_tensor_shard_num_pages: {}", input_tensor_shard_num_pages); + tt::log_debug(tt::LogOp, "output_tensor_cores: {}", output_tensor_cores); + tt::log_debug(tt::LogOp, "output_tensor_shard_shape: {}", 
output_tensor_shard_shape); + tt::log_debug(tt::LogOp, "output_tensor_shard_num_pages: {}", output_tensor_shard_num_pages); + + // L1 Scratch CB Creation + const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes(); + uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); + uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; + uint32_t cb_base_num_pages = std::lcm(input_tensor_shard_num_pages, output_tensor_shard_num_pages); + uint32_t cb_num_pages = std::lcm(num_pages_per_packet, cb_base_num_pages); + uint32_t src0_cb_index = tt::CB::c_in0; + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) + const auto reserved_packet_header_CB_index = tt::CB::c_in6; + static constexpr auto num_packet_headers_storable = 8; + static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); + tt::tt_metal::CircularBufferConfig cb_reserved_packet_header_config = + tt::tt_metal::CircularBufferConfig( + num_packet_headers_storable * packet_header_size_bytes * 2, + {{reserved_packet_header_CB_index, tt::DataFormat::RawUInt32}}) + .set_page_size(reserved_packet_header_CB_index, packet_header_size_bytes); + auto reserved_packet_header_CB_handle = + CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + + // KERNEL CREATION + // Reader + auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reader_kernel_config.compile_args = { + ring_index, // my_chip_id + src0_cb_index, // cb0_id + op_config.get_page_size(), // tensor0_page_size + }; + log_trace(tt::LogOp, "Reader Compile Args:"); + for (const auto& arg : reader_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "llama_post_binary_matmul_shape_reader.cpp", + sender_worker_core_range, + reader_kernel_config); + + // Writer + auto writer_kernel_config = tt::tt_metal::WriterDataMovementConfig{}; + writer_kernel_config.compile_args = { + ring_index, // my_chip_id + reserved_packet_header_CB_index, // reserved_packet_header_cb_id + num_packet_headers_storable, // num_packet_headers_storable + src0_cb_index, // cb0_id + num_pages_per_packet, // packet_size_in_pages + op_config.get_page_size(), // tensor0_page_size + num_targets_forward, // num_targets_forward_direction + num_targets_backward, // num_targets_backward_direction + }; + log_trace(tt::LogOp, "Writer Compile Args:"); + for (const auto& arg : writer_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "llama_post_binary_matmul_shape_writer.cpp", + sender_worker_core_range, + writer_kernel_config); + + // Kernel Runtime Args + CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready + // semaphore + 
auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); + auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); + auto cores_per_device = output_cores_vec.size() / ring_size; + TT_FATAL( + output_cores_vec.size() % ring_size == 0, + "output sharded cores must be divisible by num_links for this work distribution scheme"); + auto output_cores_this_device = std::vector( + output_cores_vec.begin() + ring_index * cores_per_device, + output_cores_vec.begin() + (ring_index + 1) * cores_per_device); + + for (uint32_t link = 0; link < num_links; link++) { + CoreCoord core = sender_worker_cores[link]; + + // construct input and output core x and y + uint32_t base_pages_per_worker = input_tensor_num_pages / num_links; + uint32_t remainder = input_tensor_num_pages % num_links; + uint32_t input_tile_id_start = link * base_pages_per_worker + std::min(link, remainder); + uint32_t input_tile_id_end = (link + 1) * base_pages_per_worker + std::min(link + 1, remainder); + + uint32_t worker_num_tiles_to_read = input_tile_id_end - input_tile_id_start; + uint32_t input_first_core_tile_start_offset = worker_num_tiles_to_read % input_tensor_shard_num_pages; + uint32_t output_first_core_tile_start_offset = worker_num_tiles_to_read % output_tensor_shard_num_pages; + + std::vector input_tensor_cores_x; + std::vector input_tensor_cores_y; + std::vector output_tensor_cores_x; + std::vector output_tensor_cores_y; + for (uint32_t i = input_tile_id_start / input_tensor_shard_num_pages; + i < (input_tile_id_end + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(input_cores_vec[i]); + input_tensor_cores_x.push_back(this_core.x); + input_tensor_cores_y.push_back(this_core.y); + } + for (uint32_t i = input_tile_id_start / output_tensor_shard_num_pages; + i < (input_tile_id_end + output_tensor_shard_num_pages - 1) / output_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(output_cores_this_device[i]); + output_tensor_cores_x.push_back(this_core.x); + output_tensor_cores_y.push_back(this_core.y); + } + + tt::log_debug(tt::LogOp, "input_tile_id_start: {}", input_tile_id_start); + tt::log_debug(tt::LogOp, "input_tile_id_end: {}", input_tile_id_end); + tt::log_debug(tt::LogOp, "worker_num_tiles_to_read: {}", worker_num_tiles_to_read); + tt::log_debug(tt::LogOp, "input_first_core_tile_start_offset: {}", input_first_core_tile_start_offset); + tt::log_debug(tt::LogOp, "output_first_core_tile_start_offset: {}", output_first_core_tile_start_offset); + tt::log_debug(tt::LogOp, "input_tensor_cores_x: {}", input_tensor_cores_x); + tt::log_debug(tt::LogOp, "input_tensor_cores_y: {}", input_tensor_cores_y); + tt::log_debug(tt::LogOp, "output_tensor_cores_x: {}", output_tensor_cores_x); + tt::log_debug(tt::LogOp, "output_tensor_cores_y: {}", output_tensor_cores_y); + + if (link == 0) { + // drain sync core is the first worker core + drain_sync_core = device->worker_core_from_logical_core(core); + } + std::optional forward_fabric_connection = + line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::FORWARD)); + std::optional backward_fabric_connection = + line_topology.is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? 
std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD)); + + // Set reader runtime args + std::vector reader_rt_args = { + input_tensor.buffer()->address(), // tensor_address0 + input_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + input_first_core_tile_start_offset, // first_core_tile_start_offset + input_tensor_cores_x.size(), // num_cores + }; + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_x.begin(), input_tensor_cores_x.end()); + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); + log_trace(tt::LogOp, "Reader Runtime Args:"); + for (const auto& arg : reader_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); + + // Set writer runtime args + bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; + bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; + uint32_t out_ready_sem_wait_value = ring_size * num_links; + std::vector writer_rt_args = { + output_tensor.buffer()->address(), // tensor_address0 + output_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + output_first_core_tile_start_offset, // first_core_tile_start_offset + output_tensor_cores_x.size(), // num_cores + wait_output_semaphore, // wait_output_semaphore + reset_global_semaphore, // reset_global_semaphore + semaphore.address(), // out_ready_sem_bank_addr (absolute address) + drain_sync_core.x, // out_ready_sem_noc0_x + drain_sync_core.y, // out_ready_sem_noc0_y + out_ready_sem_wait_value, // out_ready_sem_wait_value + }; + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), output_tensor_cores_x.end()); + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); + log_trace(tt::LogOp, "Writer Runtime Args:"); + for (const auto& arg : writer_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + writer_rt_args.push_back(forward_fabric_connection.has_value()); + if (forward_fabric_connection.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, {core}, 0); + append_worker_to_fabric_edm_sender_rt_args( + forward_fabric_connection.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_teardown_semaphore_id, + sender_worker_buffer_index_semaphore_id, + writer_rt_args); + } + writer_rt_args.push_back(backward_fabric_connection.has_value()); + if (backward_fabric_connection.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, {core}, 0); + append_worker_to_fabric_edm_sender_rt_args( + backward_fabric_connection.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_teardown_semaphore_id, + sender_worker_buffer_index_semaphore_id, + writer_rt_args); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_writer_kernel_id, {core}, writer_rt_args); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, 
semaphore, sender_worker_cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + + // update senders + auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (const auto& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + worker_writer_sender_runtime_args[0] = output.buffer()->address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace ttnn From c2a11de3748d6526757ade7f174f51aed11357b1 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Mon, 3 Feb 2025 13:44:32 -0600 Subject: [PATCH 04/25] interleaved all gather works with good PCC. Next step: add reduction kernel (dataflow + compute) --- .../operations/ccl/test_new_all_reduce.py | 26 +- ..._reduce_async_program_minimal_variants.cpp | 31 +-- .../llama_post_binary_matmul_shape_reader.cpp | 99 +++++++ .../llama_post_binary_matmul_shape_writer.cpp | 253 ++++++++++++++++++ 4 files changed, 377 insertions(+), 32 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index daf71be854c..f849bc59f9d 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -69,10 +69,6 @@ def run_all_reduce_impl( logger.info(f"Output shape: {output_shape}") - input_tensor_mesh_list = [] - output_tensor_goldens_list = [] - tt_outs = [] - try: ################################## ##### Set up input tensors/configs @@ -124,6 +120,20 @@ def run_all_reduce_impl( mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(0, 1), mesh_shape=cluster_shape), ) + # output_tensor_goldens_list = [ + # torch.sum(input_tensor, dim=cluster_axis) + # for _ in range(num_iters) + # ] + + output_tensor_goldens_list = [ + input_tensor.transpose(cluster_axis, -2) + .reshape(8, M, 4, num_cores, N // num_cores) + .transpose(-3, -2) + .reshape(cluster_shape[0], M, num_cores, -1) + .reshape(cluster_shape[0], M, -1) + for _ in range(num_iters) + ] + ################################## ##### Run the op ################################## @@ -171,15 +181,15 @@ def run_op(): tt_out_tensor = tt_outs[tensor_index] output_tensor = output_tensor_goldens_list[tensor_index] for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() logger.info(f"Checking for device {t.device().id()}") if input_dtype == ttnn.bfloat16: - eq, output = comp_equal(tt_output_tensor, output_tensor) + eq, 
output = comp_pcc(tt_output_tensor, output_tensor_) else: - eq, output = comp_pcc(tt_output_tensor, output_tensor) - if not eq: - logger.error(f"output mismatch for tensor {i}") + eq, output = comp_pcc(tt_output_tensor, output_tensor_) + logger.info(f"PCC output for {i} is: {output}") assert eq, f"{i} FAILED: {output}" finally: if enable_persistent_fabric and teardown_persistent_fabric: diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 9942c3daafd..6f913a8b266 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -140,7 +140,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" "llama_post_binary_matmul_shape_reader.cpp", sender_worker_core_range, reader_kernel_config); @@ -163,7 +163,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" "llama_post_binary_matmul_shape_writer.cpp", sender_worker_core_range, writer_kernel_config); @@ -174,12 +174,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); auto cores_per_device = output_cores_vec.size() / ring_size; - TT_FATAL( - output_cores_vec.size() % ring_size == 0, - "output sharded cores must be divisible by num_links for this work distribution scheme"); - auto output_cores_this_device = std::vector( - output_cores_vec.begin() + ring_index * cores_per_device, - output_cores_vec.begin() + (ring_index + 1) * cores_per_device); for (uint32_t link = 0; link < num_links; link++) { CoreCoord core = sender_worker_cores[link]; @@ -192,12 +186,10 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t worker_num_tiles_to_read = input_tile_id_end - input_tile_id_start; uint32_t input_first_core_tile_start_offset = worker_num_tiles_to_read % input_tensor_shard_num_pages; - uint32_t output_first_core_tile_start_offset = worker_num_tiles_to_read % output_tensor_shard_num_pages; + uint32_t output_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % output_tensor_shard_num_pages; std::vector input_tensor_cores_x; std::vector input_tensor_cores_y; - std::vector output_tensor_cores_x; - std::vector output_tensor_cores_y; for (uint32_t i = input_tile_id_start / input_tensor_shard_num_pages; i < (input_tile_id_end + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; i++) { @@ -205,13 +197,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers input_tensor_cores_x.push_back(this_core.x); input_tensor_cores_y.push_back(this_core.y); } - for (uint32_t i = 
input_tile_id_start / output_tensor_shard_num_pages; - i < (input_tile_id_end + output_tensor_shard_num_pages - 1) / output_tensor_shard_num_pages; - i++) { - auto this_core = device->worker_core_from_logical_core(output_cores_this_device[i]); - output_tensor_cores_x.push_back(this_core.x); - output_tensor_cores_y.push_back(this_core.y); - } tt::log_debug(tt::LogOp, "input_tile_id_start: {}", input_tile_id_start); tt::log_debug(tt::LogOp, "input_tile_id_end: {}", input_tile_id_end); @@ -220,8 +205,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::log_debug(tt::LogOp, "output_first_core_tile_start_offset: {}", output_first_core_tile_start_offset); tt::log_debug(tt::LogOp, "input_tensor_cores_x: {}", input_tensor_cores_x); tt::log_debug(tt::LogOp, "input_tensor_cores_y: {}", input_tensor_cores_y); - tt::log_debug(tt::LogOp, "output_tensor_cores_x: {}", output_tensor_cores_x); - tt::log_debug(tt::LogOp, "output_tensor_cores_y: {}", output_tensor_cores_y); if (link == 0) { // drain sync core is the first worker core @@ -260,10 +243,10 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t out_ready_sem_wait_value = ring_size * num_links; std::vector writer_rt_args = { output_tensor.buffer()->address(), // tensor_address0 - output_tensor_shard_num_pages, // num_tiles_per_core + input_tensor_shard_num_pages, // num_tiles_per_core worker_num_tiles_to_read, // num_tiles_to_read output_first_core_tile_start_offset, // first_core_tile_start_offset - output_tensor_cores_x.size(), // num_cores + input_tensor_cores_x.size(), // num_cores wait_output_semaphore, // wait_output_semaphore reset_global_semaphore, // reset_global_semaphore semaphore.address(), // out_ready_sem_bank_addr (absolute address) @@ -271,8 +254,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers drain_sync_core.y, // out_ready_sem_noc0_y out_ready_sem_wait_value, // out_ready_sem_wait_value }; - writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), output_tensor_cores_x.end()); - writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); + writer_rt_args.insert(writer_rt_args.end(), input_tensor_cores_x.begin(), input_tensor_cores_x.end()); + writer_rt_args.insert(writer_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); log_trace(tt::LogOp, "Writer Runtime Args:"); for (const auto& arg : writer_rt_args) { log_trace(tt::LogOp, "\t{}", arg); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp new file mode 100644 index 00000000000..4ac027b2ac7 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args +// that don't require macros to function + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include +#include "cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" + +#include "cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" + +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "cpp/ttnn/tensor/enum_types.hpp" +#include +#include + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t cb0_id = get_compile_time_arg_val(1); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(2); + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + + // print every compile and runtime arg in uint32_t + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; + DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; + DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; + DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; + for (uint32_t i = 0; i < num_cores; i++) { + DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; + DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; + } + + // interleaved addrgen + + DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; + + uint32_t tiles_read = 0; + uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t core_id = 0; + while (tiles_read < num_tiles_to_read) { + DPRINT << "tiles_read: " << tiles_read << "\n"; + uint32_t num_tiles_to_read_this_core = + 
std::min(num_tiles_per_core - shard_tile_id, num_tiles_to_read - tiles_read); + cb_reserve_back(cb0_id, num_tiles_to_read_this_core); + const uint32_t l1_write_addr = get_write_ptr(cb0_id); + uint64_t read_addr = get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], tensor_address0); + read_addr += shard_tile_id * tensor0_page_size; + + noc_async_read(read_addr, l1_write_addr, num_tiles_to_read_this_core * tensor0_page_size); + noc_async_read_barrier(); + + cb_push_back(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id = 0; + core_id++; + } + + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp new file mode 100644 index 00000000000..cfa40c68810 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args +// that don't require macros to function + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include +#include "cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" + +#include "cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" + +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" +#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" + +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "cpp/ttnn/tensor/enum_types.hpp" +#include +#include + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); +constexpr uint32_t num_packet_headers_storable = get_compile_time_arg_val(2); +constexpr uint32_t cb0_id = get_compile_time_arg_val(3); +constexpr uint32_t packet_size_in_pages = get_compile_time_arg_val(4); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(5); +constexpr uint32_t num_targets_forward_direction = get_compile_time_arg_val(6); +constexpr uint32_t num_targets_backward_direction = get_compile_time_arg_val(7); + +FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( + uint64_t noc0_dest_noc_addr, + size_t packet_header_buffer_addr, + uint32_t num_targets_forward_direction, + uint32_t num_targets_backward_direction, + FabricConnectionManager& fabric_connection, + size_t& l1_read_addr, + uint32_t payload_size_bytes) { + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + const size_t payload_l1_address = l1_read_addr; + + auto pkt_hdr = reinterpret_cast(packet_header_buffer_addr); +#ifdef DEBUG_PRINT_ENABLED + pkt_hdr->reserved2 = 
my_chip_id; +#endif + + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); + pkt_hdr->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); + + noc_async_write(payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + if (fabric_connection.has_forward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); + } + + l1_read_addr += payload_size_bytes; +} + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + bool wait_output_semaphore = get_arg_val(arg_idx++); + bool reset_global_semaphore = get_arg_val(arg_idx++); + const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); + uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + size_t arg_for_fab = arg_idx; + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "reserved_packet_header_cb_id: " << (uint32_t)reserved_packet_header_cb_id << "\n"; + DPRINT << "num_packet_headers_storable: " << (uint32_t)num_packet_headers_storable << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet_size_in_pages: " << (uint32_t)packet_size_in_pages << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + DPRINT << "num_targets_forward_direction: " << (uint32_t)num_targets_forward_direction << "\n"; + DPRINT << "num_targets_backward_direction: " << (uint32_t)num_targets_backward_direction << 
"\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; + DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; + DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; + DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; + for (uint32_t i = 0; i < num_cores; i++) { + DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; + DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; + } + DPRINT << "wait_output_semaphore: " << (uint32_t)wait_output_semaphore << "\n"; + DPRINT << "reset_global_semaphore: " << (uint32_t)reset_global_semaphore << "\n"; + DPRINT << "out_ready_sem_bank_addr: " << (uint32_t)out_ready_sem_bank_addr << "\n"; + DPRINT << "out_ready_sem_noc0_x: " << (uint32_t)out_ready_sem_noc0_x << "\n"; + DPRINT << "out_ready_sem_noc0_y: " << (uint32_t)out_ready_sem_noc0_y << "\n"; + DPRINT << "out_ready_sem_wait_value: " << (uint32_t)out_ready_sem_wait_value << "\n"; + + DPRINT << "arg_for_fab: " << (uint32_t)arg_for_fab << "\n"; + DPRINT << "fabric_connection arg 0" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 1" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 2" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 3" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 4" << get_arg_val(arg_for_fab++) << "\n"; + + // packet header cb + cb_reserve_back(reserved_packet_header_cb_id, num_packet_headers_storable); + auto packet_header_buffer_addr = get_write_ptr(reserved_packet_header_cb_id); + + if (fabric_connection.is_logically_connected()) { + fabric_connection.open(); + } + + // 1. 
mcast via fabric to remote tensor addresses + uint32_t tiles_read = 0; + uint32_t shard_tile_id = 0; // first_core_tile_start_offset; + uint32_t core_id = 0; + uint32_t writer_chip_offset = my_chip_id * num_tiles_per_core * tensor0_page_size; + + while (tiles_read < num_tiles_to_read) { + DPRINT << "tiles_read: " << tiles_read << "\n"; + uint32_t num_tiles_to_read_this_core = std::min(num_tiles_per_core - shard_tile_id, packet_size_in_pages); + num_tiles_to_read_this_core = 1; // std::min(num_tiles_to_read-tiles_read, num_tiles_to_read_this_core); + cb_wait_front(cb0_id, num_tiles_to_read_this_core); + size_t l1_read_addr = get_read_ptr(cb0_id); + + uint64_t noc0_dest_noc_addr = + get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], tensor_address0 + writer_chip_offset, 0 /*noc_id*/); + + // Offset the writer chip offset + // noc0_dest_noc_addr += writer_chip_offset; + + DPRINT << "core_noc_x[core_id]: " << (uint32_t)core_noc_x[core_id] << "\n"; + DPRINT << "core_noc_y[core_id]: " << (uint32_t)core_noc_y[core_id] << "\n"; + DPRINT << "noc0_dest_noc_addr_base: " << noc0_dest_noc_addr << "\n"; + noc0_dest_noc_addr += shard_tile_id * tensor0_page_size; + + DPRINT << "core_id: " << core_id << "\n"; + DPRINT << "num_tiles_to_read_this_core: " << num_tiles_to_read_this_core << "\n"; + DPRINT << "noc0_dest_noc_addr: " << noc0_dest_noc_addr << "\n"; + DPRINT << "shard_tile_id: " << shard_tile_id << "\n"; + + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, + packet_header_buffer_addr, + num_targets_forward_direction, + num_targets_backward_direction, + fabric_connection, + l1_read_addr, + num_tiles_to_read_this_core * tensor0_page_size); + noc_async_writes_flushed(); + + cb_pop_front(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id += num_tiles_to_read_this_core; + if (shard_tile_id >= num_tiles_per_core) { + shard_tile_id = 0; + core_id++; + } + } + + // 2. mcast output ready semaphore + auto* pkt_hdr = reinterpret_cast(packet_header_buffer_addr); + pkt_hdr->to_atomic_inc(); + pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + out_ready_sem_bank_addr, + static_cast(1), // increment 1 + 32, + static_cast(out_ready_sem_noc0_x), + static_cast(out_ready_sem_noc0_y)}); + // Write the mcast packet (forward) + if (fabric_connection.has_forward_connection()) { + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + } + // Write the mcast packet (backward) + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( + packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + } + // increment locally + uint64_t out_ready_sem_noc_addr = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); + noc_semaphore_inc(out_ready_sem_noc_addr, 1); + DPRINT << "inc done\n"; + + // 3. 
wait for mcast output ready semaphore + if (wait_output_semaphore) { + while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); + DPRINT << "waitval done\n"; + } + + // 4. global semaphore reset + if (reset_global_semaphore) { + const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); + noc_inline_dw_write(dest_noc_addr, 0); + DPRINT << "reset done\n"; + } + + if (fabric_connection.is_logically_connected()) { + fabric_connection.close(); + } + + noc_async_write_barrier(); + DPRINT << "DONE \n"; +} From 4b67c014225b6fb8d5a230b97e44b031d2771ebc Mon Sep 17 00:00:00 2001 From: avoraTT Date: Mon, 3 Feb 2025 14:56:04 -0600 Subject: [PATCH 05/25] wip reduction stuff. --- .../llama_post_binary_matmul_shape_writer.cpp | 5 +++ .../device/kernels/reduction_dataflow.cpp | 41 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index cfa40c68810..e525af71e6f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -237,6 +237,11 @@ void kernel_main() { DPRINT << "waitval done\n"; } + /* + reduction signal semaphore + mcast to reduction cores (using noc_semaphore_set_multicast) + */ + // 4. global semaphore reset if (reset_global_semaphore) { const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp new file mode 100644 index 00000000000..17c5c4b752d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args +// that don't require macros to function + +#include "dataflow_api.h" + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + /* + 1. Wait for signal (1) + 2. Clear signal + 4. Push back on interleaved all gather CB + 5. compute wait front on all gather cb + */ + + size_t arg_idx = 0; + const size_t wait_signal_sem_addr = get_arg_val(arg_idx++); + + volatile tt_l1_ptr uint32_t* l1_wait_signal_sem = + reinterpret_cast(wait_signal_sem_addr); + + // 1. Wait for signal + noc_semaphore_wait(l1_wait_signal_sem, 1); + noc_semaphoer_reset(l1_wait_signal_sem, 0); + + // 2. 
Synchronization +} From ce9e52d98ded9daf086ee3ae604f18b2d1c3864d Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Tue, 4 Feb 2025 00:20:36 -0600 Subject: [PATCH 06/25] #0: added noc semaphore multicast in writer, seeing a hang on noc_sem_wait in reduction --- .../operations/ccl/test_new_all_reduce.py | 27 +++++++++++-------- ..._reduce_async_program_minimal_variants.cpp | 26 ++++++++++++++++++ .../llama_post_binary_matmul_shape_writer.cpp | 25 ++++++++++++++++- .../device/kernels/reduction_dataflow.cpp | 14 +++++----- 4 files changed, 72 insertions(+), 20 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index f849bc59f9d..668c2cea221 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -75,21 +75,26 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - num_cores = 24 # matmul ring + num_cores = 1 # 24 # matmul ring M, N = output_shape[2:] N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE) input_shape = [*cluster_shape, M, N] input_grid = num_cores_to_rectangle_grid(num_cores, mesh_device) CORE_RANGE = [(x, y) for y in range(input_grid[1]) for x in range(input_grid[0])] + # core_range_set = ttnn.CoreRangeSet( + # [ + # ttnn.CoreRange( + # ttnn.CoreCoord(x, y), + # ttnn.CoreCoord(x, y), + # ) + # for x, y in CORE_RANGE + # ] + # ) core_range_set = ttnn.CoreRangeSet( - [ - ttnn.CoreRange( - ttnn.CoreCoord(x, y), - ttnn.CoreCoord(x, y), - ) - for x, y in CORE_RANGE - ] + { + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)), + } ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.WIDTH_SHARDED, @@ -201,14 +206,14 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links", [ - ([1, 1, 32, 3840], 1, 1), + ([1, 1, 32, 32], 1, 1), ], ) @pytest.mark.parametrize( "input_dtype", [ - # ttnn.bfloat16, - ttnn.bfloat8_b, + ttnn.bfloat16, + # ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("num_iters", [1]) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 6f913a8b266..cf2653d4e36 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -85,6 +85,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device); + std::cout << "sender worker x, y: " << sender_worker_cores[0].x << ", " << sender_worker_cores[0].y << std::endl; + // Tensor Info const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; @@ -126,6 +128,10 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers auto reserved_packet_header_CB_handle = CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + auto all_cores = input_tensor_cores.merge(sender_worker_core_range); + // Mcast args + auto writer_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); + // 
KERNEL CREATION // Reader auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; @@ -156,6 +162,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers op_config.get_page_size(), // tensor0_page_size num_targets_forward, // num_targets_forward_direction num_targets_backward, // num_targets_backward_direction + writer_semaphore_id, // writer_semaphore_addr }; log_trace(tt::LogOp, "Writer Compile Args:"); for (const auto& arg : writer_kernel_config.compile_args) { @@ -168,6 +175,18 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers sender_worker_core_range, writer_kernel_config); + // Create reduction dataflow kernel + auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reduction_reader_kernel_config.compile_args = { + writer_semaphore_id, // writer_semaphore_addr + }; + auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" + "reduction_dataflow.cpp", + input_tensor_cores, + reduction_reader_kernel_config); + // Kernel Runtime Args CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready // semaphore @@ -238,6 +257,9 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); // Set writer runtime args + CoreCoord mcast_start_core = input_tensor_cores.bounding_box().start_coord; + CoreCoord mcast_end_core = input_tensor_cores.bounding_box().end_coord; + bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; uint32_t out_ready_sem_wait_value = ring_size * num_links; @@ -253,6 +275,10 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers drain_sync_core.x, // out_ready_sem_noc0_x drain_sync_core.y, // out_ready_sem_noc0_y out_ready_sem_wait_value, // out_ready_sem_wait_value + mcast_start_core.x, // mcast_dest_noc_start_x + mcast_start_core.y, // mcast_dest_noc_start_y + mcast_end_core.x, // mcast_dest_noc_end_x + mcast_end_core.y, // mcast_dest_noc_end_y }; writer_rt_args.insert(writer_rt_args.end(), input_tensor_cores_x.begin(), input_tensor_cores_x.end()); writer_rt_args.insert(writer_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index e525af71e6f..e695ae33ca9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -89,6 +89,7 @@ void kernel_main() { /////////////////////////////////////////////////// // ARGS /////////////////////////////////////////////////// + uint32_t writer_semaphore_addr = get_semaphore(get_compile_time_arg_val(8)); size_t arg_idx = 0; // Load the input tensor spec @@ -103,6 +104,11 @@ void kernel_main() { const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + const uint32_t 
mcast_dest_noc_start_x = get_arg_val(arg_idx++); + const uint32_t mcast_dest_noc_start_y = get_arg_val(arg_idx++); + const uint32_t mcast_dest_noc_end_x = get_arg_val(arg_idx++); + const uint32_t mcast_dest_noc_end_y = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); arg_idx += num_cores; tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); @@ -236,11 +242,29 @@ void kernel_main() { while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); DPRINT << "waitval done\n"; } + noc_async_write_barrier(); /* reduction signal semaphore mcast to reduction cores (using noc_semaphore_set_multicast) */ + const uint64_t writer_semaphore_noc_addr = get_noc_multicast_addr( + mcast_dest_noc_start_x, + mcast_dest_noc_start_y, + mcast_dest_noc_end_x, + mcast_dest_noc_end_y, + writer_semaphore_addr, + 0); + DPRINT << "mcast_dest_noc_start_x: " << mcast_dest_noc_start_x << "\n"; + DPRINT << "mcast_dest_noc_start_y: " << mcast_dest_noc_start_y << "\n"; + DPRINT << "mcast_dest_noc_end_x: " << mcast_dest_noc_end_x << "\n"; + DPRINT << "mcast_dest_noc_end_y: " << mcast_dest_noc_end_y << "\n"; + + // const uint64_t writer_semaphore_noc_addr = multicast_data_noc | writer_semaphore_addr; + volatile tt_l1_ptr uint32_t* writer_semaphore_addr_ptr = + reinterpret_cast(writer_semaphore_addr); + *writer_semaphore_addr_ptr = VALID; + noc_semaphore_set_multicast(writer_semaphore_addr, writer_semaphore_noc_addr, num_cores, false, false, 0); // 4. global semaphore reset if (reset_global_semaphore) { @@ -253,6 +277,5 @@ void kernel_main() { fabric_connection.close(); } - noc_async_write_barrier(); DPRINT << "DONE \n"; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index 17c5c4b752d..faeae2fbd31 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -26,16 +26,14 @@ void kernel_main() { 4. Push back on interleaved all gather CB 5. compute wait front on all gather cb */ + uint32_t writer_semaphore_addr = get_semaphore(get_compile_time_arg_val(0)); - size_t arg_idx = 0; - const size_t wait_signal_sem_addr = get_arg_val(arg_idx++); - - volatile tt_l1_ptr uint32_t* l1_wait_signal_sem = - reinterpret_cast(wait_signal_sem_addr); + volatile tt_l1_ptr uint32_t* writer_semaphore_addr_ptr = + reinterpret_cast(writer_semaphore_addr); + DPRINT << "Wait \n"; // 1. Wait for signal - noc_semaphore_wait(l1_wait_signal_sem, 1); - noc_semaphoer_reset(l1_wait_signal_sem, 0); - + noc_semaphore_wait(writer_semaphore_addr_ptr, VALID); + DPRINT << " Wait Over \n"; // 2. Synchronization } From bf9cb4701a6f86072153584476dcd1cd7f02065e Mon Sep 17 00:00:00 2001 From: avoraTT Date: Tue, 4 Feb 2025 07:10:25 -0600 Subject: [PATCH 07/25] Add fix for reduction worker hang. 
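The multicast rectangle for the reduction-worker signal is now built from
worker (NOC) coordinates via device->worker_core_from_logical_core() rather
than the logical bounding box, the writer arms its local semaphore with VALID
before the main loop, and the multicast only fires after the output-ready wait
completes. A condensed sketch of the resulting writer/reduction handshake,
using the argument names from the kernels in this patch (VALID and the
noc_semaphore_* / get_noc_multicast_addr calls are the dataflow_api primitives
already used below; this is a trimmed illustration, not the full kernel):

    // Writer worker: arm the local semaphore value, wait until every chip has
    // landed its shard, then multicast VALID to the grid of reduction cores.
    volatile tt_l1_ptr uint32_t* send_sem =
        reinterpret_cast<volatile tt_l1_ptr uint32_t*>(reduction_semaphore_send_addr);
    noc_semaphore_set(send_sem, VALID);

    while (*reinterpret_cast<volatile tt_l1_ptr uint32_t*>(out_ready_sem_bank_addr) <
           out_ready_sem_wait_value);

    const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr(
        mcast_dest_noc_start_x, mcast_dest_noc_start_y,   // worker coords, not logical
        mcast_dest_noc_end_x, mcast_dest_noc_end_y,
        reduction_semaphore_send_addr);
    noc_semaphore_set_multicast(
        reduction_semaphore_send_addr, reduction_semaphore_recv_noc_addr,
        num_cores, false, false, 0);

    // Reduction core: block until the writer's multicast lands.
    volatile tt_l1_ptr uint32_t* recv_sem =
        reinterpret_cast<volatile tt_l1_ptr uint32_t*>(signal_semaphore_addr);
    noc_semaphore_wait(recv_sem, VALID);
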
--- .../operations/ccl/test_new_all_reduce.py | 27 +++++------ ..._reduce_async_program_minimal_variants.cpp | 19 +++----- .../llama_post_binary_matmul_shape_writer.cpp | 46 +++++++++---------- .../device/kernels/reduction_dataflow.cpp | 19 ++------ 4 files changed, 43 insertions(+), 68 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 668c2cea221..7e57896bf3d 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -75,26 +75,21 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - num_cores = 1 # 24 # matmul ring + num_cores = 24 # matmul ring + core_offset = 1 M, N = output_shape[2:] N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE) input_shape = [*cluster_shape, M, N] - input_grid = num_cores_to_rectangle_grid(num_cores, mesh_device) - CORE_RANGE = [(x, y) for y in range(input_grid[1]) for x in range(input_grid[0])] - # core_range_set = ttnn.CoreRangeSet( - # [ - # ttnn.CoreRange( - # ttnn.CoreCoord(x, y), - # ttnn.CoreCoord(x, y), - # ) - # for x, y in CORE_RANGE - # ] - # ) + CORE_RANGE = [(x, y) for y in range(compute_grid_size.y) for x in range(compute_grid_size.x)] core_range_set = ttnn.CoreRangeSet( - { - ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)), - } + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in CORE_RANGE[core_offset : core_offset + num_cores] + ] ) input_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.WIDTH_SHARDED, @@ -206,7 +201,7 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links", [ - ([1, 1, 32, 32], 1, 1), + ([1, 1, 32, 3840], 1, 1), ], ) @pytest.mark.parametrize( diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index cf2653d4e36..1ceb24acb08 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -45,11 +45,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const GlobalSemaphore semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode) { - std::cout << "RUNNING NEW ALL REDUCE ASYNC" << std::endl; - std::cout << "num_links: " << num_links << std::endl; - std::cout << "ring_size: " << ring_size << std::endl; - std::cout << "ring_index: " << ring_index << std::endl; - tt::tt_metal::Program program{}; const bool enable_async_output_tensor = false; TT_FATAL( @@ -85,8 +80,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device); - std::cout << "sender worker x, y: " << sender_worker_cores[0].x << ", " << sender_worker_cores[0].y << std::endl; - // Tensor Info const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; @@ -128,9 +121,9 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers auto 
reserved_packet_header_CB_handle = CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + // Reduction kernel stuff auto all_cores = input_tensor_cores.merge(sender_worker_core_range); - // Mcast args - auto writer_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); + auto reduction_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); // KERNEL CREATION // Reader @@ -162,7 +155,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers op_config.get_page_size(), // tensor0_page_size num_targets_forward, // num_targets_forward_direction num_targets_backward, // num_targets_backward_direction - writer_semaphore_id, // writer_semaphore_addr + reduction_semaphore_id, // reduction_semaphore_send_addr }; log_trace(tt::LogOp, "Writer Compile Args:"); for (const auto& arg : writer_kernel_config.compile_args) { @@ -178,7 +171,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers // Create reduction dataflow kernel auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; reduction_reader_kernel_config.compile_args = { - writer_semaphore_id, // writer_semaphore_addr + reduction_semaphore_id, // signal_semaphore_addr }; auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( program, @@ -257,8 +250,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); // Set writer runtime args - CoreCoord mcast_start_core = input_tensor_cores.bounding_box().start_coord; - CoreCoord mcast_end_core = input_tensor_cores.bounding_box().end_coord; + auto mcast_start_core = device->worker_core_from_logical_core(input_tensor_cores.bounding_box().start_coord); + auto mcast_end_core = device->worker_core_from_logical_core(input_tensor_cores.bounding_box().end_coord); bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index e695ae33ca9..7228da10bca 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -89,7 +89,7 @@ void kernel_main() { /////////////////////////////////////////////////// // ARGS /////////////////////////////////////////////////// - uint32_t writer_semaphore_addr = get_semaphore(get_compile_time_arg_val(8)); + uint32_t reduction_semaphore_send_addr = get_semaphore(get_compile_time_arg_val(8)); size_t arg_idx = 0; // Load the input tensor spec @@ -109,6 +109,10 @@ void kernel_main() { const uint32_t mcast_dest_noc_end_x = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_end_y = get_arg_val(arg_idx++); + volatile tt_l1_ptr uint32_t* reduction_semaphore_send_addr_ptr = + reinterpret_cast(reduction_semaphore_send_addr); + noc_semaphore_set(reduction_semaphore_send_addr_ptr, VALID); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); arg_idx += num_cores; tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); @@ -241,30 +245,23 @@ void 
kernel_main() { if (wait_output_semaphore) { while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); DPRINT << "waitval done\n"; - } - noc_async_write_barrier(); - /* - reduction signal semaphore - mcast to reduction cores (using noc_semaphore_set_multicast) - */ - const uint64_t writer_semaphore_noc_addr = get_noc_multicast_addr( - mcast_dest_noc_start_x, - mcast_dest_noc_start_y, - mcast_dest_noc_end_x, - mcast_dest_noc_end_y, - writer_semaphore_addr, - 0); - DPRINT << "mcast_dest_noc_start_x: " << mcast_dest_noc_start_x << "\n"; - DPRINT << "mcast_dest_noc_start_y: " << mcast_dest_noc_start_y << "\n"; - DPRINT << "mcast_dest_noc_end_x: " << mcast_dest_noc_end_x << "\n"; - DPRINT << "mcast_dest_noc_end_y: " << mcast_dest_noc_end_y << "\n"; - - // const uint64_t writer_semaphore_noc_addr = multicast_data_noc | writer_semaphore_addr; - volatile tt_l1_ptr uint32_t* writer_semaphore_addr_ptr = - reinterpret_cast(writer_semaphore_addr); - *writer_semaphore_addr_ptr = VALID; - noc_semaphore_set_multicast(writer_semaphore_addr, writer_semaphore_noc_addr, num_cores, false, false, 0); + // Signal the reduction workers + const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr( + mcast_dest_noc_start_x, + mcast_dest_noc_start_y, + mcast_dest_noc_end_x, + mcast_dest_noc_end_y, + reduction_semaphore_send_addr); + + noc_semaphore_set_multicast( + reduction_semaphore_send_addr, + reduction_semaphore_recv_noc_addr, + num_cores, + false, + false, // TODO: Why? + 0); + } // 4. global semaphore reset if (reset_global_semaphore) { @@ -277,5 +274,6 @@ void kernel_main() { fabric_connection.close(); } + noc_async_write_barrier(); DPRINT << "DONE \n"; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index faeae2fbd31..7b2cca4b788 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -2,19 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args -// that don't require macros to function - #include "dataflow_api.h" -/////////////////////////////////////////////////// -// COMPILE TIME ARGS -/////////////////////////////////////////////////// - -/* - * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) - * dispatch implementations depending on those invocation parameters. - */ void kernel_main() { /////////////////////////////////////////////////// // ARGS @@ -26,14 +15,14 @@ void kernel_main() { 4. Push back on interleaved all gather CB 5. compute wait front on all gather cb */ - uint32_t writer_semaphore_addr = get_semaphore(get_compile_time_arg_val(0)); + uint32_t signal_semaphore_addr = get_semaphore(get_compile_time_arg_val(0)); - volatile tt_l1_ptr uint32_t* writer_semaphore_addr_ptr = - reinterpret_cast(writer_semaphore_addr); + volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = + reinterpret_cast(signal_semaphore_addr); DPRINT << "Wait \n"; // 1. Wait for signal - noc_semaphore_wait(writer_semaphore_addr_ptr, VALID); + noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); DPRINT << " Wait Over \n"; // 2. 
Synchronization } From 600352e93d9fc1663382f23c60b0fb8ef15b0743 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Tue, 4 Feb 2025 08:06:05 -0600 Subject: [PATCH 08/25] Add reduction and output cb. Currently, the reduction kernel does a copy into the output cb (temporary). Next: add reduction compute kernel. --- .../operations/ccl/test_new_all_reduce.py | 13 +++-- ..._reduce_async_program_minimal_variants.cpp | 54 ++++++++++++++----- .../llama_post_binary_matmul_shape_writer.cpp | 4 +- .../device/kernels/reduction_dataflow.cpp | 33 +++++++----- 4 files changed, 73 insertions(+), 31 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 7e57896bf3d..aec3608dcac 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -120,19 +120,21 @@ def run_all_reduce_impl( mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(0, 1), mesh_shape=cluster_shape), ) + # All-Reduce Golden # output_tensor_goldens_list = [ # torch.sum(input_tensor, dim=cluster_axis) # for _ in range(num_iters) # ] - output_tensor_goldens_list = [ - input_tensor.transpose(cluster_axis, -2) - .reshape(8, M, 4, num_cores, N // num_cores) + # Interleaved All-Gather Golden + output_tensor_golden = input_tensor.transpose(cluster_axis, -2) + output_tensor_golden = ( + output_tensor_golden.reshape(*output_tensor_golden.shape[:3], num_cores, N // num_cores) .transpose(-3, -2) .reshape(cluster_shape[0], M, num_cores, -1) .reshape(cluster_shape[0], M, -1) - for _ in range(num_iters) - ] + ) + output_tensor_goldens_list = [output_tensor_golden for _ in range(num_iters)] ################################## ##### Run the op @@ -181,6 +183,7 @@ def run_op(): tt_out_tensor = tt_outs[tensor_index] output_tensor = output_tensor_goldens_list[tensor_index] for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + # get_device_tensors returns row major, so we need to select the correct golden tensor output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() logger.info(f"Checking for device {t.device().id()}") diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 1ceb24acb08..5fba232c6ff 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -125,6 +125,46 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers auto all_cores = input_tensor_cores.merge(sender_worker_core_range); auto reduction_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); + /* reduction cb */ + uint32_t reduction_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t reduction_CB_tiles = input_tensor_num_pages / input_tensor_cores.num_cores() * ring_size; + uint32_t reduction_CB_size = reduction_CB_tiles * reduction_CB_single_tile_size; + + uint32_t reduction_cb_index = tt::CBIndex::c_1; + tt::tt_metal::CircularBufferConfig reduction_cb_config = + tt::tt_metal::CircularBufferConfig(reduction_CB_size, {{reduction_cb_index, df}}) + 
.set_page_size(reduction_cb_index, reduction_CB_single_tile_size); + // .set_globally_allocated_address(*output_tensor.buffer()); // TODO: Remove once new cb attached for output + auto cb_reduction = tt::tt_metal::CreateCircularBuffer(program, all_cores, reduction_cb_config); + + /* out cb */ + uint32_t out_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t out_CB_tiles = input_tensor_num_pages / input_tensor_cores.num_cores(); + uint32_t out_CB_size = out_CB_tiles * out_CB_single_tile_size; + + uint32_t out_cb_index = tt::CBIndex::c_2; + tt::tt_metal::CircularBufferConfig out_cb_config = + tt::tt_metal::CircularBufferConfig(out_CB_size, {{out_cb_index, df}}) + .set_page_size(out_cb_index, out_CB_single_tile_size) + .set_globally_allocated_address(*output_tensor.buffer()); // TODO: Remove once new cb attached for output + auto cb_out = tt::tt_metal::CreateCircularBuffer( + program, input_tensor_cores, out_cb_config); // TODO: This should be the output cores instead + + // Create reduction dataflow kernel + auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reduction_reader_kernel_config.compile_args = { + reduction_cb_index, // reduction_cb_index + reduction_CB_tiles, // total_num_reduction_tiles + reduction_semaphore_id, // signal_semaphore_addr + out_cb_index, // out_cb_index + }; + auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" + "reduction_dataflow.cpp", + input_tensor_cores, + reduction_reader_kernel_config); + // KERNEL CREATION // Reader auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; @@ -168,18 +208,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers sender_worker_core_range, writer_kernel_config); - // Create reduction dataflow kernel - auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; - reduction_reader_kernel_config.compile_args = { - reduction_semaphore_id, // signal_semaphore_addr - }; - auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" - "reduction_dataflow.cpp", - input_tensor_cores, - reduction_reader_kernel_config); - // Kernel Runtime Args CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready // semaphore @@ -257,7 +285,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; uint32_t out_ready_sem_wait_value = ring_size * num_links; std::vector writer_rt_args = { - output_tensor.buffer()->address(), // tensor_address0 + reduction_cb_index, // tensor_address0 input_tensor_shard_num_pages, // num_tiles_per_core worker_num_tiles_to_read, // num_tiles_to_read output_first_core_tile_start_offset, // first_core_tile_start_offset diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index 7228da10bca..bca5b6aeb8e 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -93,7 
+93,9 @@ void kernel_main() { size_t arg_idx = 0; // Load the input tensor spec - address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t reduction_output_cb_id = get_arg_val(arg_idx++); + address_t tensor_address0 = get_write_ptr(reduction_output_cb_id); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); uint32_t num_tiles_to_read = get_arg_val(arg_idx++); uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index 7b2cca4b788..2a9d0b37eb8 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -8,21 +8,30 @@ void kernel_main() { /////////////////////////////////////////////////// // ARGS /////////////////////////////////////////////////// - - /* - 1. Wait for signal (1) - 2. Clear signal - 4. Push back on interleaved all gather CB - 5. compute wait front on all gather cb - */ - uint32_t signal_semaphore_addr = get_semaphore(get_compile_time_arg_val(0)); + constexpr uint32_t cb_id = get_compile_time_arg_val(0); + constexpr uint32_t total_num_reduction_tiles = get_compile_time_arg_val(1); + const uint32_t signal_semaphore_addr = get_semaphore(get_compile_time_arg_val(2)); + constexpr uint32_t out_cb_id = get_compile_time_arg_val(3); volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = reinterpret_cast(signal_semaphore_addr); - DPRINT << "Wait \n"; - // 1. Wait for signal + // 1. Wait for signal from All-Gather worker noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); - DPRINT << " Wait Over \n"; - // 2. Synchronization + noc_semaphore_set(signal_semaphore_addr_ptr, 0); + + // 2. 
Signal compute kernel to start processing + cb_push_back(cb_id, total_num_reduction_tiles); + + // Temp copy from reduction to output + uint32_t l1_write_addr = get_write_ptr(out_cb_id); + uint64_t l1_read_addr = get_noc_addr(get_read_ptr(cb_id)); + uint32_t tile_size = get_tile_size(cb_id); + + for (uint32_t i = 0; i < total_num_reduction_tiles; i++) { + noc_async_read(l1_read_addr, l1_write_addr, tile_size); + l1_read_addr += (uint64_t)tile_size; + l1_write_addr += tile_size; + } + noc_async_read_barrier(); } From 5d4d879678e45843abc35008ff98cd7681bf1ab7 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Tue, 4 Feb 2025 11:50:13 -0600 Subject: [PATCH 09/25] All-reduce for FF1/FF3 works --- .../operations/ccl/test_new_all_reduce.py | 29 ++++----- .../device/all_reduce_async_op.cpp | 1 - ..._reduce_async_program_minimal_variants.cpp | 23 ++++++- .../device/kernels/eltwise_binary_kernel.cpp | 60 +++++++++++++++++++ .../device/kernels/reduction_dataflow.cpp | 13 ---- 5 files changed, 93 insertions(+), 33 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index aec3608dcac..548903599dd 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -105,7 +105,7 @@ def run_all_reduce_impl( ttnn.BufferType.L1, ttnn.ShardSpec( core_range_set, - [M, N_per_shard * 4], + [M, N_per_shard], ttnn.ShardOrientation.ROW_MAJOR, ), ) @@ -121,20 +121,17 @@ def run_all_reduce_impl( ) # All-Reduce Golden - # output_tensor_goldens_list = [ - # torch.sum(input_tensor, dim=cluster_axis) - # for _ in range(num_iters) - # ] - - # Interleaved All-Gather Golden - output_tensor_golden = input_tensor.transpose(cluster_axis, -2) - output_tensor_golden = ( - output_tensor_golden.reshape(*output_tensor_golden.shape[:3], num_cores, N // num_cores) - .transpose(-3, -2) - .reshape(cluster_shape[0], M, num_cores, -1) - .reshape(cluster_shape[0], M, -1) - ) - output_tensor_goldens_list = [output_tensor_golden for _ in range(num_iters)] + output_tensor_goldens_list = [torch.sum(input_tensor, dim=cluster_axis) for _ in range(num_iters)] + + # # Interleaved All-Gather Golden + # output_tensor_golden = input_tensor.transpose(cluster_axis, -2) + # output_tensor_golden = ( + # output_tensor_golden.reshape(*output_tensor_golden.shape[:3], num_cores, N // num_cores) + # .transpose(-3, -2) + # .reshape(cluster_shape[0], M, num_cores, -1) + # .reshape(cluster_shape[0], M, -1) + # ) + # output_tensor_goldens_list = [output_tensor_golden for _ in range(num_iters)] ################################## ##### Run the op @@ -211,7 +208,7 @@ def run_op(): "input_dtype", [ ttnn.bfloat16, - # ttnn.bfloat8_b, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("num_iters", [1]) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp index 06e594066d9..9bfaed5160f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp @@ -123,7 +123,6 @@ static void validate_output_tensor_allocation(const std::vector& output_ std::vector AllReduceAsync::compute_output_specs(const std::vector& input_tensors) 
const { const auto& input_tensor = input_tensors[0]; auto shape = input_tensor.get_padded_shape(); // TODO: Replace with get_logical_shape() - shape[3] *= this->ring_size; return {TensorSpec( shape, TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config))}; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 5fba232c6ff..583214243e5 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -127,7 +127,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers /* reduction cb */ uint32_t reduction_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); - uint32_t reduction_CB_tiles = input_tensor_num_pages / input_tensor_cores.num_cores() * ring_size; + uint32_t reduction_CB_tiles = input_tensor_shard_num_pages * ring_size; uint32_t reduction_CB_size = reduction_CB_tiles * reduction_CB_single_tile_size; uint32_t reduction_cb_index = tt::CBIndex::c_1; @@ -139,7 +139,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers /* out cb */ uint32_t out_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); - uint32_t out_CB_tiles = input_tensor_num_pages / input_tensor_cores.num_cores(); + uint32_t out_CB_tiles = input_tensor_shard_num_pages; uint32_t out_CB_size = out_CB_tiles * out_CB_single_tile_size; uint32_t out_cb_index = tt::CBIndex::c_2; @@ -156,7 +156,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers reduction_cb_index, // reduction_cb_index reduction_CB_tiles, // total_num_reduction_tiles reduction_semaphore_id, // signal_semaphore_addr - out_cb_index, // out_cb_index }; auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( program, @@ -165,6 +164,24 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers input_tensor_cores, reduction_reader_kernel_config); + // Create reduction dataflow kernel + auto reduction_kernel_config = tt::tt_metal::ComputeConfig{}; + reduction_kernel_config.compile_args = { + reduction_cb_index, // reduction_cb_index + out_cb_index, // out_cb_index + }; + auto reduction_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" + "eltwise_binary_kernel.cpp", + input_tensor_cores, + reduction_kernel_config); + std::vector reduction_kernel_rt_args = { + ring_size, // num_blocks + input_tensor_shard_num_pages, // block_num_tiles + }; + tt::tt_metal::SetRuntimeArgs(program, reduction_kernel_id, input_tensor_cores, reduction_kernel_rt_args); + // KERNEL CREATION // Reader auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp new file mode 100644 index 00000000000..85b85140ed8 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t cb_in0 = get_compile_time_arg_val(0); + constexpr uint32_t cb_out0 = get_compile_time_arg_val(1); + constexpr uint32_t cb_in1 = cb_in0; + + uint32_t rt_args_idx = 0; + const uint32_t num_blocks = get_arg_val(rt_args_idx++); + const uint32_t block_num_tiles = get_arg_val(rt_args_idx++); + const uint32_t copy_first_block = num_blocks % 2 != 0; + + constexpr uint32_t max_dst_tiles = 8; + + cb_wait_front(cb_in0, num_blocks * block_num_tiles); + cb_reserve_back(cb_out0, block_num_tiles); + + binary_op_init_common(cb_in0, cb_in1, cb_out0); + add_tiles_init(cb_in0, cb_in1, true); + + uint32_t num_pack_iters = (block_num_tiles + max_dst_tiles - 1) / max_dst_tiles; + uint32_t block_num_tiles_cnt = 0; + + for (uint32_t p = 0; p < num_pack_iters; ++p) { + uint32_t num_tiles_to_pack = std::min(max_dst_tiles, block_num_tiles - block_num_tiles_cnt); + tile_regs_acquire(); + for (uint32_t block = 0; block < num_blocks; block += 2) { + if (copy_first_block && block == 0) { + // TODO: Future support + } else { + for (uint32_t i = 0; i < num_tiles_to_pack; ++i) { + add_tiles( + cb_in0, + cb_in1, + block * block_num_tiles + p * max_dst_tiles + i, + (block + 1) * block_num_tiles + p * max_dst_tiles + i, + i); + } + } + } + tile_regs_commit(); + + // Pack output tiles + tile_regs_wait(); + for (uint32_t i = 0; i < num_tiles_to_pack; ++i) { + pack_tile(i, cb_out0, p * max_dst_tiles + i); + } + tile_regs_release(); + + block_num_tiles_cnt += num_tiles_to_pack; + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index 2a9d0b37eb8..116fd60cc1f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -11,7 +11,6 @@ void kernel_main() { constexpr uint32_t cb_id = get_compile_time_arg_val(0); constexpr uint32_t total_num_reduction_tiles = get_compile_time_arg_val(1); const uint32_t signal_semaphore_addr = get_semaphore(get_compile_time_arg_val(2)); - constexpr uint32_t out_cb_id = get_compile_time_arg_val(3); volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = reinterpret_cast(signal_semaphore_addr); @@ -22,16 +21,4 @@ void kernel_main() { // 2. Signal compute kernel to start processing cb_push_back(cb_id, total_num_reduction_tiles); - - // Temp copy from reduction to output - uint32_t l1_write_addr = get_write_ptr(out_cb_id); - uint64_t l1_read_addr = get_noc_addr(get_read_ptr(cb_id)); - uint32_t tile_size = get_tile_size(cb_id); - - for (uint32_t i = 0; i < total_num_reduction_tiles; i++) { - noc_async_read(l1_read_addr, l1_write_addr, tile_size); - l1_read_addr += (uint64_t)tile_size; - l1_write_addr += tile_size; - } - noc_async_read_barrier(); } From 27414ef85ca1c6eff1784a2de61458c5e70a1606 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Tue, 4 Feb 2025 14:31:00 -0600 Subject: [PATCH 10/25] Add support for reshard. TODO: add support to drop padding from input tensor. 
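The output of the all-reduce can now land on a different shard grid than the
input: the test takes separate input_num_cores and output_num_cores, builds a
second shard spec for the output, and the program factory sizes the reduction
and output CBs from the output tensor's cores and shard pages. A small
self-contained sketch of the per-core shard width this implies (C++; the
shard_width helper is illustrative and mirrors the test's
round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE)), showing where the padding
behind the TODO comes from once N stops dividing evenly across the output
cores:

    #include <cstdint>

    constexpr uint32_t TILE = 32;

    // Width (in elements) each output core holds: ceil over cores, rounded up
    // to a whole tile.
    constexpr uint32_t shard_width(uint32_t n, uint32_t cores) {
        uint32_t per_core = (n + cores - 1) / cores;
        return ((per_core + TILE - 1) / TILE) * TILE;
    }

    // FF2/DO case from the test below, N = 2304:
    static_assert(shard_width(2304, 8) == 288);   // 9 tiles per core, no padding
    static_assert(shard_width(2304, 16) == 160);  // 5 tiles per core; 16*160 - 2304 = 256 padded columns
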
--- .../operations/ccl/test_new_all_reduce.py | 48 ++++++++++++------- ..._reduce_async_program_minimal_variants.cpp | 41 +++++++++------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 548903599dd..49b457c1ec1 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -27,6 +27,8 @@ def run_all_reduce_impl( cluster_axis, input_dtype, num_links, + input_num_cores, + output_num_cores, num_iters=1, enable_async=False, trace_mode=False, @@ -75,10 +77,10 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - num_cores = 24 # matmul ring core_offset = 1 M, N = output_shape[2:] - N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE) + N_per_shard = round_up(math.ceil(N / input_num_cores), ttnn.TILE_SIZE) + output_N_per_shard = round_up(math.ceil(N / output_num_cores), ttnn.TILE_SIZE) input_shape = [*cluster_shape, M, N] CORE_RANGE = [(x, y) for y in range(compute_grid_size.y) for x in range(compute_grid_size.x)] @@ -88,7 +90,7 @@ def run_all_reduce_impl( ttnn.CoreCoord(x, y), ttnn.CoreCoord(x, y), ) - for x, y in CORE_RANGE[core_offset : core_offset + num_cores] + for x, y in CORE_RANGE[core_offset : core_offset + input_num_cores] ] ) input_mem_config = ttnn.MemoryConfig( @@ -100,12 +102,21 @@ def run_all_reduce_impl( ttnn.ShardOrientation.ROW_MAJOR, ), ) + output_core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in CORE_RANGE[core_offset : core_offset + output_num_cores] + ] + ) output_mem_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, ttnn.ShardSpec( - core_range_set, - [M, N_per_shard], + output_core_range_set, + [M, output_N_per_shard], ttnn.ShardOrientation.ROW_MAJOR, ), ) @@ -123,16 +134,6 @@ def run_all_reduce_impl( # All-Reduce Golden output_tensor_goldens_list = [torch.sum(input_tensor, dim=cluster_axis) for _ in range(num_iters)] - # # Interleaved All-Gather Golden - # output_tensor_golden = input_tensor.transpose(cluster_axis, -2) - # output_tensor_golden = ( - # output_tensor_golden.reshape(*output_tensor_golden.shape[:3], num_cores, N // num_cores) - # .transpose(-3, -2) - # .reshape(cluster_shape[0], M, num_cores, -1) - # .reshape(cluster_shape[0], M, -1) - # ) - # output_tensor_goldens_list = [output_tensor_golden for _ in range(num_iters)] - ################################## ##### Run the op ################################## @@ -181,7 +182,11 @@ def run_op(): output_tensor = output_tensor_goldens_list[tensor_index] for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): # get_device_tensors returns row major, so we need to select the correct golden tensor - output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) + if cluster_axis == 0: + output_tensor_ = output_tensor[i % cluster_shape[not (cluster_axis)]].unsqueeze(0).unsqueeze(0) + else: + output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() logger.info(f"Checking for device {t.device().id()}") @@ -199,9 +204,12 @@ def run_op(): @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "output_shape, cluster_axis, num_links", + "output_shape, cluster_axis, num_links, 
input_num_cores, output_num_cores", [ - ([1, 1, 32, 3840], 1, 1), + ([1, 1, 32, 1536], 1, 1, 24, 8), # QKV all reduce + ([1, 1, 32, 3840], 1, 1, 24, 24), # FF1 all reduce + # TODO: Use unpadded shapes and output to 16 + ([1, 1, 32, 2304], 0, 1, 24, 8), # FF2/DO all reduce ], ) @pytest.mark.parametrize( @@ -231,6 +239,8 @@ def test_all_reduce( cluster_axis, input_dtype, num_links, + input_num_cores, + output_num_cores, num_iters, enable_async, use_program_cache, @@ -242,6 +252,8 @@ def test_all_reduce( cluster_axis, input_dtype, num_links, + input_num_cores, + output_num_cores, num_iters=num_iters, enable_async=enable_async, ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 583214243e5..1f19f2c26e5 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -122,24 +122,23 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); // Reduction kernel stuff - auto all_cores = input_tensor_cores.merge(sender_worker_core_range); + auto all_cores = output_tensor_cores.merge(sender_worker_core_range); auto reduction_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); /* reduction cb */ - uint32_t reduction_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); - uint32_t reduction_CB_tiles = input_tensor_shard_num_pages * ring_size; + uint32_t reduction_CB_single_tile_size = output_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t reduction_CB_tiles = output_tensor_shard_num_pages * ring_size; uint32_t reduction_CB_size = reduction_CB_tiles * reduction_CB_single_tile_size; uint32_t reduction_cb_index = tt::CBIndex::c_1; tt::tt_metal::CircularBufferConfig reduction_cb_config = tt::tt_metal::CircularBufferConfig(reduction_CB_size, {{reduction_cb_index, df}}) .set_page_size(reduction_cb_index, reduction_CB_single_tile_size); - // .set_globally_allocated_address(*output_tensor.buffer()); // TODO: Remove once new cb attached for output auto cb_reduction = tt::tt_metal::CreateCircularBuffer(program, all_cores, reduction_cb_config); /* out cb */ - uint32_t out_CB_single_tile_size = input_tensor.get_tensor_spec().tile().get_tile_size(df); - uint32_t out_CB_tiles = input_tensor_shard_num_pages; + uint32_t out_CB_single_tile_size = output_tensor.get_tensor_spec().tile().get_tile_size(df); + uint32_t out_CB_tiles = output_tensor_shard_num_pages; uint32_t out_CB_size = out_CB_tiles * out_CB_single_tile_size; uint32_t out_cb_index = tt::CBIndex::c_2; @@ -148,7 +147,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers .set_page_size(out_cb_index, out_CB_single_tile_size) .set_globally_allocated_address(*output_tensor.buffer()); // TODO: Remove once new cb attached for output auto cb_out = tt::tt_metal::CreateCircularBuffer( - program, input_tensor_cores, out_cb_config); // TODO: This should be the output cores instead + program, output_tensor_cores, out_cb_config); // TODO: This should be the output cores instead // Create reduction dataflow kernel auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; @@ 
-161,7 +160,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers program, "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" "reduction_dataflow.cpp", - input_tensor_cores, + output_tensor_cores, reduction_reader_kernel_config); // Create reduction dataflow kernel @@ -174,13 +173,13 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers program, "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" "eltwise_binary_kernel.cpp", - input_tensor_cores, + output_tensor_cores, reduction_kernel_config); std::vector reduction_kernel_rt_args = { - ring_size, // num_blocks - input_tensor_shard_num_pages, // block_num_tiles + ring_size, // num_blocks + output_tensor_shard_num_pages, // block_num_tiles }; - tt::tt_metal::SetRuntimeArgs(program, reduction_kernel_id, input_tensor_cores, reduction_kernel_rt_args); + tt::tt_metal::SetRuntimeArgs(program, reduction_kernel_id, output_tensor_cores, reduction_kernel_rt_args); // KERNEL CREATION // Reader @@ -230,7 +229,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers // semaphore auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); - auto cores_per_device = output_cores_vec.size() / ring_size; for (uint32_t link = 0; link < num_links; link++) { CoreCoord core = sender_worker_cores[link]; @@ -247,6 +245,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers std::vector input_tensor_cores_x; std::vector input_tensor_cores_y; + std::vector output_tensor_cores_x; + std::vector output_tensor_cores_y; for (uint32_t i = input_tile_id_start / input_tensor_shard_num_pages; i < (input_tile_id_end + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; i++) { @@ -254,6 +254,13 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers input_tensor_cores_x.push_back(this_core.x); input_tensor_cores_y.push_back(this_core.y); } + for (uint32_t i = input_tile_id_start / output_tensor_shard_num_pages; + i < (input_tile_id_end + output_tensor_shard_num_pages - 1) / output_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(output_cores_vec[i]); + output_tensor_cores_x.push_back(this_core.x); + output_tensor_cores_y.push_back(this_core.y); + } tt::log_debug(tt::LogOp, "input_tile_id_start: {}", input_tile_id_start); tt::log_debug(tt::LogOp, "input_tile_id_end: {}", input_tile_id_end); @@ -303,10 +310,10 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t out_ready_sem_wait_value = ring_size * num_links; std::vector writer_rt_args = { reduction_cb_index, // tensor_address0 - input_tensor_shard_num_pages, // num_tiles_per_core + output_tensor_shard_num_pages, // num_tiles_per_core worker_num_tiles_to_read, // num_tiles_to_read output_first_core_tile_start_offset, // first_core_tile_start_offset - input_tensor_cores_x.size(), // num_cores + output_tensor_cores_x.size(), // num_cores wait_output_semaphore, // wait_output_semaphore reset_global_semaphore, // reset_global_semaphore semaphore.address(), // out_ready_sem_bank_addr (absolute address) @@ -318,8 +325,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers mcast_end_core.x, // mcast_dest_noc_end_x mcast_end_core.y, // mcast_dest_noc_end_y }; - writer_rt_args.insert(writer_rt_args.end(), 
input_tensor_cores_x.begin(), input_tensor_cores_x.end()); - writer_rt_args.insert(writer_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), output_tensor_cores_x.end()); + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); log_trace(tt::LogOp, "Writer Runtime Args:"); for (const auto& arg : writer_rt_args) { log_trace(tt::LogOp, "\t{}", arg); From 31119ef77c91143ef59bca705e0c9d2b8a8001ed Mon Sep 17 00:00:00 2001 From: avoraTT Date: Wed, 5 Feb 2025 07:40:05 -0600 Subject: [PATCH 11/25] Add support for unpadded shapes. --- .../ttnn/unit_tests/operations/ccl/test_new_all_reduce.py | 8 ++++---- .../device/all_reduce_async_program_minimal_variants.cpp | 2 +- .../kernels/llama_post_binary_matmul_shape_reader.cpp | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 49b457c1ec1..281561fdc81 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -121,6 +121,7 @@ def run_all_reduce_impl( ), ) + logger.info(f"Input shape: {input_shape[2:]}, Padded shape: {[M, N_per_shard * input_num_cores]}") input_tensor = torch.randn(input_shape) tt_input_tensor = ttnn.from_torch( input_tensor, @@ -206,10 +207,9 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ - ([1, 1, 32, 1536], 1, 1, 24, 8), # QKV all reduce - ([1, 1, 32, 3840], 1, 1, 24, 24), # FF1 all reduce - # TODO: Use unpadded shapes and output to 16 - ([1, 1, 32, 2304], 0, 1, 24, 8), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 1, 24, 8), # QKV all reduce + ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce ], ) @pytest.mark.parametrize( diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 1f19f2c26e5..34c244fe1c6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -240,7 +240,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t input_tile_id_end = (link + 1) * base_pages_per_worker + std::min(link + 1, remainder); uint32_t worker_num_tiles_to_read = input_tile_id_end - input_tile_id_start; - uint32_t input_first_core_tile_start_offset = worker_num_tiles_to_read % input_tensor_shard_num_pages; + uint32_t input_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % input_tensor_shard_num_pages; uint32_t output_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % output_tensor_shard_num_pages; std::vector input_tensor_cores_x; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp index 4ac027b2ac7..f50361e378e 100644 --- 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp @@ -75,7 +75,7 @@ void kernel_main() { DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; uint32_t tiles_read = 0; - uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t shard_tile_id = 0; // first_core_tile_start_offset; uint32_t core_id = 0; while (tiles_read < num_tiles_to_read) { DPRINT << "tiles_read: " << tiles_read << "\n"; From 24e65bb245f7e4ef726aff29cd655e018f831a65 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Wed, 5 Feb 2025 07:41:30 -0600 Subject: [PATCH 12/25] Remove dprints. --- .../llama_post_binary_matmul_shape_reader.cpp | 21 --------- .../llama_post_binary_matmul_shape_writer.cpp | 47 ------------------- 2 files changed, 68 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp index f50361e378e..32f1ef2ff34 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp @@ -53,27 +53,8 @@ void kernel_main() { tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); arg_idx += num_cores; - // print every compile and runtime arg in uint32_t - DPRINT << "ct args: \n"; - DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; - DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; - DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; - - DPRINT << "rt args: \n"; - DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; - DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; - DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; - DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; - DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; - for (uint32_t i = 0; i < num_cores; i++) { - DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; - DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; - } - // interleaved addrgen - DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; - uint32_t tiles_read = 0; uint32_t shard_tile_id = 0; // first_core_tile_start_offset; uint32_t core_id = 0; @@ -94,6 +75,4 @@ void kernel_main() { shard_tile_id = 0; core_id++; } - - DPRINT << "DONE \n"; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index bca5b6aeb8e..80c9756ce16 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -122,40 +122,6 @@ void kernel_main() { size_t arg_for_fab = arg_idx; auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); - DPRINT << "ct args: \n"; - DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; - DPRINT << 
"reserved_packet_header_cb_id: " << (uint32_t)reserved_packet_header_cb_id << "\n"; - DPRINT << "num_packet_headers_storable: " << (uint32_t)num_packet_headers_storable << "\n"; - DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; - DPRINT << "packet_size_in_pages: " << (uint32_t)packet_size_in_pages << "\n"; - DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; - DPRINT << "num_targets_forward_direction: " << (uint32_t)num_targets_forward_direction << "\n"; - DPRINT << "num_targets_backward_direction: " << (uint32_t)num_targets_backward_direction << "\n"; - - DPRINT << "rt args: \n"; - DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; - DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; - DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; - DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; - DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; - for (uint32_t i = 0; i < num_cores; i++) { - DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; - DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; - } - DPRINT << "wait_output_semaphore: " << (uint32_t)wait_output_semaphore << "\n"; - DPRINT << "reset_global_semaphore: " << (uint32_t)reset_global_semaphore << "\n"; - DPRINT << "out_ready_sem_bank_addr: " << (uint32_t)out_ready_sem_bank_addr << "\n"; - DPRINT << "out_ready_sem_noc0_x: " << (uint32_t)out_ready_sem_noc0_x << "\n"; - DPRINT << "out_ready_sem_noc0_y: " << (uint32_t)out_ready_sem_noc0_y << "\n"; - DPRINT << "out_ready_sem_wait_value: " << (uint32_t)out_ready_sem_wait_value << "\n"; - - DPRINT << "arg_for_fab: " << (uint32_t)arg_for_fab << "\n"; - DPRINT << "fabric_connection arg 0" << get_arg_val(arg_for_fab++) << "\n"; - DPRINT << "fabric_connection arg 1" << get_arg_val(arg_for_fab++) << "\n"; - DPRINT << "fabric_connection arg 2" << get_arg_val(arg_for_fab++) << "\n"; - DPRINT << "fabric_connection arg 3" << get_arg_val(arg_for_fab++) << "\n"; - DPRINT << "fabric_connection arg 4" << get_arg_val(arg_for_fab++) << "\n"; - // packet header cb cb_reserve_back(reserved_packet_header_cb_id, num_packet_headers_storable); auto packet_header_buffer_addr = get_write_ptr(reserved_packet_header_cb_id); @@ -171,7 +137,6 @@ void kernel_main() { uint32_t writer_chip_offset = my_chip_id * num_tiles_per_core * tensor0_page_size; while (tiles_read < num_tiles_to_read) { - DPRINT << "tiles_read: " << tiles_read << "\n"; uint32_t num_tiles_to_read_this_core = std::min(num_tiles_per_core - shard_tile_id, packet_size_in_pages); num_tiles_to_read_this_core = 1; // std::min(num_tiles_to_read-tiles_read, num_tiles_to_read_this_core); cb_wait_front(cb0_id, num_tiles_to_read_this_core); @@ -183,16 +148,8 @@ void kernel_main() { // Offset the writer chip offset // noc0_dest_noc_addr += writer_chip_offset; - DPRINT << "core_noc_x[core_id]: " << (uint32_t)core_noc_x[core_id] << "\n"; - DPRINT << "core_noc_y[core_id]: " << (uint32_t)core_noc_y[core_id] << "\n"; - DPRINT << "noc0_dest_noc_addr_base: " << noc0_dest_noc_addr << "\n"; noc0_dest_noc_addr += shard_tile_id * tensor0_page_size; - DPRINT << "core_id: " << core_id << "\n"; - DPRINT << "num_tiles_to_read_this_core: " << num_tiles_to_read_this_core << "\n"; - DPRINT << "noc0_dest_noc_addr: " << noc0_dest_noc_addr << "\n"; - DPRINT << "shard_tile_id: " << shard_tile_id << "\n"; - write_and_advance_local_read_address_for_fabric_write( noc0_dest_noc_addr, packet_header_buffer_addr, @@ 
-241,12 +198,10 @@ void kernel_main() { uint64_t out_ready_sem_noc_addr = safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); noc_semaphore_inc(out_ready_sem_noc_addr, 1); - DPRINT << "inc done\n"; // 3. wait for mcast output ready semaphore if (wait_output_semaphore) { while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); - DPRINT << "waitval done\n"; // Signal the reduction workers const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr( @@ -269,7 +224,6 @@ void kernel_main() { if (reset_global_semaphore) { const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); noc_inline_dw_write(dest_noc_addr, 0); - DPRINT << "reset done\n"; } if (fabric_connection.is_logically_connected()) { @@ -277,5 +231,4 @@ void kernel_main() { } noc_async_write_barrier(); - DPRINT << "DONE \n"; } From 72510151a5f0c1040951d3ee79d324e684c2ac58 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Wed, 5 Feb 2025 11:38:01 -0600 Subject: [PATCH 13/25] Fix bug in mcast bbox. Fix QKV output num cores. --- tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py | 2 +- .../device/all_reduce_async_program_minimal_variants.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 281561fdc81..d2047080ec6 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -207,7 +207,7 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ - ([1, 1, 32, 1280], 1, 1, 24, 8), # QKV all reduce + ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce ], diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 34c244fe1c6..2b8986f444d 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -302,8 +302,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); // Set writer runtime args - auto mcast_start_core = device->worker_core_from_logical_core(input_tensor_cores.bounding_box().start_coord); - auto mcast_end_core = device->worker_core_from_logical_core(input_tensor_cores.bounding_box().end_coord); + auto mcast_start_core = device->worker_core_from_logical_core(output_tensor_cores.bounding_box().start_coord); + auto mcast_end_core = device->worker_core_from_logical_core(output_tensor_cores.bounding_box().end_coord); bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; From b67ea6da968e5572e7b9a2a526032cface3a87d3 Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Thu, 6 Feb 2025 09:47:18 +0000 Subject: [PATCH 14/25] #0: multi-link support added for 3 all_reduce. Link=3 fails with kernel error Cannot add semaphore on core (x=0,y=0). 
Max number of semaphores (8) reached./build_metal.sh --debug Link=2 hangs at reduction signal wait only for second worker --- .../operations/ccl/test_new_all_reduce.py | 10 +- ..._reduce_async_program_minimal_variants.cpp | 101 ++++++++++++++---- .../llama_post_binary_matmul_shape_writer.cpp | 9 +- .../device/kernels/reduction_dataflow.cpp | 8 +- 4 files changed, 98 insertions(+), 30 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index d2047080ec6..45bf2eb4d2b 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -77,7 +77,7 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - core_offset = 1 + core_offset = 8 M, N = output_shape[2:] N_per_shard = round_up(math.ceil(N / input_num_cores), ttnn.TILE_SIZE) output_N_per_shard = round_up(math.ceil(N / output_num_cores), ttnn.TILE_SIZE) @@ -207,16 +207,16 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ - ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce - ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce - ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce + # ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 2, 24, 24), # FF1 all reduce + # ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce ], ) @pytest.mark.parametrize( "input_dtype", [ ttnn.bfloat16, - ttnn.bfloat8_b, + # ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("num_iters", [1]) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 2b8986f444d..30fc6de0152 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -122,8 +122,13 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); // Reduction kernel stuff - auto all_cores = output_tensor_cores.merge(sender_worker_core_range); - auto reduction_semaphore_id = tt::tt_metal::CreateSemaphore(program, all_cores, 0); + const auto all_cores = device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0)); + // Create reduction semaphore vector for each link + std::vector reduction_semaphore_ids; + reduction_semaphore_ids.reserve(num_links); + for (uint32_t i = 0; i < num_links; i++) { + reduction_semaphore_ids.push_back(tt::tt_metal::CreateSemaphore(program, all_cores, 0)); + } /* reduction cb */ uint32_t reduction_CB_single_tile_size = output_tensor.get_tensor_spec().tile().get_tile_size(df); @@ -152,9 +157,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers // Create reduction dataflow kernel auto reduction_reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; reduction_reader_kernel_config.compile_args = { - reduction_cb_index, // reduction_cb_index - reduction_CB_tiles, // total_num_reduction_tiles - reduction_semaphore_id, // signal_semaphore_addr + reduction_cb_index, // reduction_cb_index + reduction_CB_tiles, // total_num_reduction_tiles }; auto 
reduction_reader_kernel_id = tt::tt_metal::CreateKernel( program, @@ -211,7 +215,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers op_config.get_page_size(), // tensor0_page_size num_targets_forward, // num_targets_forward_direction num_targets_backward, // num_targets_backward_direction - reduction_semaphore_id, // reduction_semaphore_send_addr }; log_trace(tt::LogOp, "Writer Compile Args:"); for (const auto& arg : writer_kernel_config.compile_args) { @@ -224,22 +227,44 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers sender_worker_core_range, writer_kernel_config); + // Per link Calculations + std::vector num_output_cores_in_link(num_links); + std::vector output_tensor_pages_in_link(num_links); + uint32_t total_output_cores = output_tensor_cores.num_cores(); + uint32_t output_cores_per_link = total_output_cores / num_links; + uint32_t remainder = total_output_cores % num_links; + uint32_t remaining_output_cores = total_output_cores; + uint32_t remaining_output_pages = input_tensor_num_pages; + + for (uint32_t link = 0; link < num_links - 1; link++) { + if (remainder == 0 || output_cores_per_link % 2 == 0) { + num_output_cores_in_link[link] = output_cores_per_link; + } else { + num_output_cores_in_link[link] = output_cores_per_link + 1; + } + remaining_output_cores -= num_output_cores_in_link[link]; + output_tensor_pages_in_link[link] = num_output_cores_in_link[link] * output_tensor_shard_num_pages; + remaining_output_pages -= output_tensor_pages_in_link[link]; + } + TT_FATAL(remaining_output_cores > 0, "remaining output cores for last link should be greater than 0"); + num_output_cores_in_link[num_links - 1] = remaining_output_cores; + output_tensor_pages_in_link[num_links - 1] = remaining_output_pages; + std::cout << "num_output_cores_in_link[last]: " << num_output_cores_in_link[num_links - 1] << std::endl; + // Kernel Runtime Args CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready // semaphore auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); + uint32_t worker_output_core_start_idx = 0; + uint32_t worker_input_core_start_idx = 0; for (uint32_t link = 0; link < num_links; link++) { CoreCoord core = sender_worker_cores[link]; // construct input and output core x and y - uint32_t base_pages_per_worker = input_tensor_num_pages / num_links; - uint32_t remainder = input_tensor_num_pages % num_links; - uint32_t input_tile_id_start = link * base_pages_per_worker + std::min(link, remainder); - uint32_t input_tile_id_end = (link + 1) * base_pages_per_worker + std::min(link + 1, remainder); - uint32_t worker_num_tiles_to_read = input_tile_id_end - input_tile_id_start; + uint32_t worker_num_tiles_to_read = output_tensor_pages_in_link[link]; uint32_t input_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % input_tensor_shard_num_pages; uint32_t output_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % output_tensor_shard_num_pages; @@ -247,23 +272,36 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers std::vector input_tensor_cores_y; std::vector output_tensor_cores_x; std::vector output_tensor_cores_y; - for (uint32_t i = input_tile_id_start / input_tensor_shard_num_pages; - i < (input_tile_id_end + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; - i++) { - auto this_core = 
device->worker_core_from_logical_core(input_cores_vec[i]); - input_tensor_cores_x.push_back(this_core.x); - input_tensor_cores_y.push_back(this_core.y); + std::vector output_coreranges_this_link; + output_coreranges_this_link.reserve(num_output_cores_in_link[link]); + if (link < num_links - 1) { + TT_FATAL( + worker_num_tiles_to_read % input_tensor_shard_num_pages == 0, + "worker_num_tiles_to_read must be divisible by input_tensor_shard_num_pages, currently shard tile " + "offset is not supported"); } - for (uint32_t i = input_tile_id_start / output_tensor_shard_num_pages; - i < (input_tile_id_end + output_tensor_shard_num_pages - 1) / output_tensor_shard_num_pages; + + for (uint32_t i = worker_output_core_start_idx; + i < worker_output_core_start_idx + num_output_cores_in_link[link]; i++) { + output_coreranges_this_link.push_back(CoreRange(output_cores_vec[i])); auto this_core = device->worker_core_from_logical_core(output_cores_vec[i]); output_tensor_cores_x.push_back(this_core.x); output_tensor_cores_y.push_back(this_core.y); } + worker_output_core_start_idx += num_output_cores_in_link[link]; + + for (uint32_t i = worker_input_core_start_idx; + i < worker_input_core_start_idx + + (worker_num_tiles_to_read + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(input_cores_vec[i]); + input_tensor_cores_x.push_back(this_core.x); + input_tensor_cores_y.push_back(this_core.y); + } + worker_input_core_start_idx += + (worker_num_tiles_to_read + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; - tt::log_debug(tt::LogOp, "input_tile_id_start: {}", input_tile_id_start); - tt::log_debug(tt::LogOp, "input_tile_id_end: {}", input_tile_id_end); tt::log_debug(tt::LogOp, "worker_num_tiles_to_read: {}", worker_num_tiles_to_read); tt::log_debug(tt::LogOp, "input_first_core_tile_start_offset: {}", input_first_core_tile_start_offset); tt::log_debug(tt::LogOp, "output_first_core_tile_start_offset: {}", output_first_core_tile_start_offset); @@ -302,8 +340,18 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); // Set writer runtime args - auto mcast_start_core = device->worker_core_from_logical_core(output_tensor_cores.bounding_box().start_coord); - auto mcast_end_core = device->worker_core_from_logical_core(output_tensor_cores.bounding_box().end_coord); + CoreRangeSet output_crs_this_link(std::move(output_coreranges_this_link)); + output_crs_this_link = output_crs_this_link.merge_ranges(); + auto mcast_start_core = device->worker_core_from_logical_core(output_crs_this_link.bounding_box().start_coord); + auto mcast_end_core = device->worker_core_from_logical_core(output_crs_this_link.bounding_box().end_coord); + + std::cout << "output_crs_this_link num_cores: " << output_crs_this_link.num_cores() << "for link : " << link + << std::endl; + std::cout << "output_tensor_cores_x.size(): " << output_tensor_cores_x.size() << std::endl; + std::cout << "mcast_start_core.x : " << output_crs_this_link.bounding_box().start_coord.x + << " , mcast start y : " << output_crs_this_link.bounding_box().start_coord.y << std::endl; + std::cout << "mcast_end_core.x : " << output_crs_this_link.bounding_box().end_coord.x + << " , mcast end y : " << output_crs_this_link.bounding_box().end_coord.y << std::endl; bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; bool reset_global_semaphore = 
(link == 0) && !enable_async_output_tensor; @@ -320,6 +368,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers drain_sync_core.x, // out_ready_sem_noc0_x drain_sync_core.y, // out_ready_sem_noc0_y out_ready_sem_wait_value, // out_ready_sem_wait_value + reduction_semaphore_ids[link], // reduction_semaphore_id mcast_start_core.x, // mcast_dest_noc_start_x mcast_start_core.y, // mcast_dest_noc_start_y mcast_end_core.x, // mcast_dest_noc_end_x @@ -356,6 +405,12 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers writer_rt_args); } tt::tt_metal::SetRuntimeArgs(program, worker_sender_writer_kernel_id, {core}, writer_rt_args); + + std::vector reduction_reader_rt_args = { + reduction_semaphore_ids[link], // reduction_semaphore_id + }; + tt::tt_metal::SetRuntimeArgs( + program, reduction_reader_kernel_id, output_crs_this_link, reduction_reader_rt_args); } auto override_runtime_arguments_callback = diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index 80c9756ce16..253135fdc11 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -89,7 +89,6 @@ void kernel_main() { /////////////////////////////////////////////////// // ARGS /////////////////////////////////////////////////// - uint32_t reduction_semaphore_send_addr = get_semaphore(get_compile_time_arg_val(8)); size_t arg_idx = 0; // Load the input tensor spec @@ -106,11 +105,14 @@ void kernel_main() { const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + const uint32_t reduction_semaphore_send_addr = get_semaphore(get_arg_val(arg_idx++)); const uint32_t mcast_dest_noc_start_x = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_start_y = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_end_x = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_end_y = get_arg_val(arg_idx++); + DPRINT << "reduction_output_cb_id: " << reduction_semaphore_send_addr << "\n"; + volatile tt_l1_ptr uint32_t* reduction_semaphore_send_addr_ptr = reinterpret_cast(reduction_semaphore_send_addr); noc_semaphore_set(reduction_semaphore_send_addr_ptr, VALID); @@ -199,6 +201,7 @@ void kernel_main() { safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); noc_semaphore_inc(out_ready_sem_noc_addr, 1); + DPRINT << "wait for output semphore \n"; // 3. wait for mcast output ready semaphore if (wait_output_semaphore) { while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); @@ -220,6 +223,8 @@ void kernel_main() { 0); } + DPRINT << "wait done for output semphore \n"; + // 4. 
global semaphore reset if (reset_global_semaphore) { const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); @@ -231,4 +236,6 @@ void kernel_main() { } noc_async_write_barrier(); + + DPRINT << "writer done \n"; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index 116fd60cc1f..4773c7bd643 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -10,14 +10,20 @@ void kernel_main() { /////////////////////////////////////////////////// constexpr uint32_t cb_id = get_compile_time_arg_val(0); constexpr uint32_t total_num_reduction_tiles = get_compile_time_arg_val(1); - const uint32_t signal_semaphore_addr = get_semaphore(get_compile_time_arg_val(2)); + + // runtime args + size_t arg_idx = 0; + const uint32_t signal_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); + DPRINT << "signal_semaphore_addr: " << signal_semaphore_addr << "\n"; volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = reinterpret_cast(signal_semaphore_addr); + DPRINT << "WAIT for Reduction Signal \n"; // 1. Wait for signal from All-Gather worker noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); noc_semaphore_set(signal_semaphore_addr_ptr, 0); + DPRINT << "Pushing Tiles \n"; // 2. Signal compute kernel to start processing cb_push_back(cb_id, total_num_reduction_tiles); From 21f1779f3aa4f76f98462f34115e24b5fe1c9c3c Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Thu, 6 Feb 2025 18:21:58 +0000 Subject: [PATCH 15/25] #0: multi-link=3 works --- .../operations/ccl/test_new_all_reduce.py | 12 +++--- tt_metal/api/tt-metalium/semaphore.hpp | 2 +- ..._reduce_async_program_minimal_variants.cpp | 12 +++--- .../llama_post_binary_matmul_shape_reader.cpp | 2 +- .../llama_post_binary_matmul_shape_writer.cpp | 43 ++++++++++--------- .../device/kernels/reduction_dataflow.cpp | 7 +-- 6 files changed, 40 insertions(+), 38 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 45bf2eb4d2b..3812716a7c0 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -77,7 +77,7 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - core_offset = 8 + core_offset = num_links M, N = output_shape[2:] N_per_shard = round_up(math.ceil(N / input_num_cores), ttnn.TILE_SIZE) output_N_per_shard = round_up(math.ceil(N / output_num_cores), ttnn.TILE_SIZE) @@ -207,16 +207,16 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ - # ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce - ([1, 1, 32, 3584], 1, 2, 24, 24), # FF1 all reduce - # ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 3, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 3, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 3, 24, 16), # FF2/DO all reduce ], ) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, - # ttnn.bfloat8_b, + # ttnn.bfloat16, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("num_iters", [1]) diff --git a/tt_metal/api/tt-metalium/semaphore.hpp b/tt_metal/api/tt-metalium/semaphore.hpp index 
3b3ccc84b21..0356a851c8d 100644 --- a/tt_metal/api/tt-metalium/semaphore.hpp +++ b/tt_metal/api/tt-metalium/semaphore.hpp @@ -12,7 +12,7 @@ namespace tt { namespace tt_metal { -constexpr std::uint32_t NUM_SEMAPHORES = 8; +constexpr std::uint32_t NUM_SEMAPHORES = 16; class Semaphore { public: diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 30fc6de0152..d7b87e40932 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -308,10 +308,9 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::log_debug(tt::LogOp, "input_tensor_cores_x: {}", input_tensor_cores_x); tt::log_debug(tt::LogOp, "input_tensor_cores_y: {}", input_tensor_cores_y); - if (link == 0) { // drain sync core is the first worker core - drain_sync_core = device->worker_core_from_logical_core(core); - } + drain_sync_core = device->worker_core_from_logical_core(core); + std::optional forward_fabric_connection = line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) ? std::nullopt @@ -353,9 +352,9 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers std::cout << "mcast_end_core.x : " << output_crs_this_link.bounding_box().end_coord.x << " , mcast end y : " << output_crs_this_link.bounding_box().end_coord.y << std::endl; - bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; - bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; - uint32_t out_ready_sem_wait_value = ring_size * num_links; + bool wait_output_semaphore = true; //(link == 0) && !enable_async_output_tensor; + bool reset_global_semaphore = true; //(link == 0) && !enable_async_output_tensor; + uint32_t out_ready_sem_wait_value = ring_size * 1; std::vector writer_rt_args = { reduction_cb_index, // tensor_address0 output_tensor_shard_num_pages, // num_tiles_per_core @@ -373,6 +372,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers mcast_start_core.y, // mcast_dest_noc_start_y mcast_end_core.x, // mcast_dest_noc_end_x mcast_end_core.y, // mcast_dest_noc_end_y + link, // link }; writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), output_tensor_cores_x.end()); writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp index 32f1ef2ff34..61fa4f7f56c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp @@ -59,7 +59,7 @@ void kernel_main() { uint32_t shard_tile_id = 0; // first_core_tile_start_offset; uint32_t core_id = 0; while (tiles_read < num_tiles_to_read) { - DPRINT << "tiles_read: " << tiles_read << "\n"; + // DPRINT << "tiles_read: " << tiles_read << "\n"; uint32_t num_tiles_to_read_this_core = 
std::min(num_tiles_per_core - shard_tile_id, num_tiles_to_read - tiles_read); cb_reserve_back(cb0_id, num_tiles_to_read_this_core); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp index 253135fdc11..16939e68178 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -110,8 +110,9 @@ void kernel_main() { const uint32_t mcast_dest_noc_start_y = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_end_x = get_arg_val(arg_idx++); const uint32_t mcast_dest_noc_end_y = get_arg_val(arg_idx++); + const uint32_t link = get_arg_val(arg_idx++); - DPRINT << "reduction_output_cb_id: " << reduction_semaphore_send_addr << "\n"; + // DPRINT << "reduction_output_cb_id: " << reduction_semaphore_send_addr << "\n"; volatile tt_l1_ptr uint32_t* reduction_semaphore_send_addr_ptr = reinterpret_cast(reduction_semaphore_send_addr); @@ -140,7 +141,7 @@ void kernel_main() { while (tiles_read < num_tiles_to_read) { uint32_t num_tiles_to_read_this_core = std::min(num_tiles_per_core - shard_tile_id, packet_size_in_pages); - num_tiles_to_read_this_core = 1; // std::min(num_tiles_to_read-tiles_read, num_tiles_to_read_this_core); + num_tiles_to_read_this_core = std::min(num_tiles_to_read - tiles_read, num_tiles_to_read_this_core); cb_wait_front(cb0_id, num_tiles_to_read_this_core); size_t l1_read_addr = get_read_ptr(cb0_id); @@ -201,29 +202,29 @@ void kernel_main() { safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); noc_semaphore_inc(out_ready_sem_noc_addr, 1); - DPRINT << "wait for output semphore \n"; + // DPRINT << "wait for output semphore \n"; // 3. wait for mcast output ready semaphore if (wait_output_semaphore) { while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); - - // Signal the reduction workers - const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr( - mcast_dest_noc_start_x, - mcast_dest_noc_start_y, - mcast_dest_noc_end_x, - mcast_dest_noc_end_y, - reduction_semaphore_send_addr); - - noc_semaphore_set_multicast( - reduction_semaphore_send_addr, - reduction_semaphore_recv_noc_addr, - num_cores, - false, - false, // TODO: Why? - 0); } - DPRINT << "wait done for output semphore \n"; + // Signal the reduction workers + const uint64_t reduction_semaphore_recv_noc_addr = get_noc_multicast_addr( + mcast_dest_noc_start_x, + mcast_dest_noc_start_y, + mcast_dest_noc_end_x, + mcast_dest_noc_end_y, + reduction_semaphore_send_addr); + + noc_semaphore_set_multicast( + reduction_semaphore_send_addr, + reduction_semaphore_recv_noc_addr, + num_cores, + false, // TODO: Why? + false, // TODO: Why? + 0); + + // DPRINT << "wait done for output semphore \n"; // 4. 
global semaphore reset if (reset_global_semaphore) { @@ -237,5 +238,5 @@ void kernel_main() { noc_async_write_barrier(); - DPRINT << "writer done \n"; + // DPRINT << "writer done \n"; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index 4773c7bd643..d874af00fb9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -14,16 +14,17 @@ void kernel_main() { // runtime args size_t arg_idx = 0; const uint32_t signal_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); - DPRINT << "signal_semaphore_addr: " << signal_semaphore_addr << "\n"; + // DPRINT << "signal_semaphore_addr: " << signal_semaphore_addr << "\n"; volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = reinterpret_cast(signal_semaphore_addr); - DPRINT << "WAIT for Reduction Signal \n"; + // DPRINT << "WAIT for Reduction Signal \n"; + // DPRINT << "signal_semaphore_addr_ptr: " << *signal_semaphore_addr_ptr << "\n"; // 1. Wait for signal from All-Gather worker noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); noc_semaphore_set(signal_semaphore_addr_ptr, 0); - DPRINT << "Pushing Tiles \n"; + // DPRINT << "Pushing Tiles \n"; // 2. Signal compute kernel to start processing cb_push_back(cb_id, total_num_reduction_tiles); From e4a7b09b2cd143a9aced4fb6e164d4a24621d8a9 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Thu, 6 Feb 2025 13:31:02 -0600 Subject: [PATCH 16/25] Add cleanup for multi-link. --- .../operations/ccl/test_new_all_reduce.py | 8 +- ..._reduce_async_program_minimal_variants.cpp | 166 +++++++++--------- .../device/kernels/reduction_dataflow.cpp | 4 - 3 files changed, 86 insertions(+), 92 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 3812716a7c0..0e58e53a664 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -210,12 +210,18 @@ def run_op(): ([1, 1, 32, 1280], 1, 3, 24, 40), # QKV all reduce ([1, 1, 32, 3584], 1, 3, 24, 24), # FF1 all reduce ([1, 1, 32, 2048], 0, 3, 24, 16), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 2, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 2, 24, 24), # FF1 all reduce + # ([1, 1, 32, 2048], 0, 2, 24, 16), # FF2/DO all reduce # Not supported + ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce + ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce + ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce ], ) @pytest.mark.parametrize( "input_dtype", [ - # ttnn.bfloat16, + ttnn.bfloat16, ttnn.bfloat8_b, ], ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index d7b87e40932..2f41c6db444 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -6,6 +6,7 @@ #include #include +#include #include "ttnn/tensor/tensor_impl.hpp" #include "all_reduce_async_op.hpp" #include 
"ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" @@ -33,6 +34,14 @@ namespace ttnn { using namespace ccl; +CoreRangeSet cores_to_corerangeset(const std::vector& cores) { + std::vector core_ranges; + for (const auto& core : cores) { + core_ranges.push_back(CoreRange(core)); + } + return CoreRangeSet(core_ranges); +} + operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers( const Tensor& input_tensor, std::optional forward_device, @@ -85,9 +94,12 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; const auto input_tensor_shard_shape = input_tensor.memory_config().shard_spec->shape; const auto input_tensor_shard_num_pages = input_tensor_shard_shape[0] * input_tensor_shard_shape[1] / TILE_HW; + const auto num_input_cores = input_tensor_cores.num_cores(); + const auto output_tensor_num_pages = output_tensor.buffer()->num_pages(); const auto output_tensor_cores = output_tensor.memory_config().shard_spec->grid; const auto output_tensor_shard_shape = output_tensor.memory_config().shard_spec->shape; const auto output_tensor_shard_num_pages = output_tensor_shard_shape[0] * output_tensor_shard_shape[1] / TILE_HW; + const auto num_output_cores = output_tensor_cores.num_cores(); tt::log_debug(tt::LogOp, "input_tensor_num_pages: {}", input_tensor_num_pages); tt::log_debug(tt::LogOp, "input_tensor_cores: {}", input_tensor_cores); @@ -122,12 +134,50 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); // Reduction kernel stuff - const auto all_cores = device->worker_cores(HalProgrammableCoreType::TENSIX, device->get_sub_device_ids().at(0)); + auto all_cores = output_tensor_cores.merge(sender_worker_core_range); + auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); + auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); + + // Create output tensor splits + std::vector output_corerangeset_per_link; + std::vector num_output_cores_in_link(num_links, 0); + uint32_t output_cores_per_link = tt::div_up(output_tensor_cores.num_cores(), num_links); + uint32_t num_assigned_cores = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_cores_this_link = std::min(output_cores_per_link, num_output_cores - num_assigned_cores); + output_corerangeset_per_link.emplace_back(cores_to_corerangeset(std::vector( + output_cores_vec.begin() + num_assigned_cores, + output_cores_vec.begin() + num_assigned_cores + num_cores_this_link))); + num_output_cores_in_link[link] = num_cores_this_link; + num_assigned_cores += num_cores_this_link; + } + + // Create output tensor page splits + std::vector output_tensor_pages_in_link(num_links, 0); + uint32_t num_assigned_pages = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_output_pages_per_link = output_tensor_shard_num_pages * num_output_cores_in_link[link]; + uint32_t num_pages_this_link = + std::min(num_output_pages_per_link, output_tensor_num_pages - num_assigned_pages); + output_tensor_pages_in_link[link] = num_pages_this_link; + num_assigned_pages += num_pages_this_link; + } + + // Create input tensor splits + std::vector num_input_cores_in_link(num_links, 0); + uint32_t input_cores_per_link = + tt::div_up(output_tensor_pages_in_link[0], input_tensor_shard_num_pages); // TODO: Add validation + uint32_t 
num_assigned_input_cores = 0; + for (uint32_t link = 0; link < num_links; link++) { + uint32_t num_cores_this_link = std::min(input_cores_per_link, num_input_cores - num_assigned_input_cores); + num_input_cores_in_link[link] = num_cores_this_link; + num_assigned_input_cores += num_cores_this_link; + } + // Create reduction semaphore vector for each link - std::vector reduction_semaphore_ids; - reduction_semaphore_ids.reserve(num_links); - for (uint32_t i = 0; i < num_links; i++) { - reduction_semaphore_ids.push_back(tt::tt_metal::CreateSemaphore(program, all_cores, 0)); + std::vector reduction_semaphore_ids(num_links, 0); + for (uint32_t link = 0; link < num_links; link++) { + reduction_semaphore_ids[link] = tt::tt_metal::CreateSemaphore(program, all_cores, 0); } /* reduction cb */ @@ -227,53 +277,12 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers sender_worker_core_range, writer_kernel_config); - // Per link Calculations - std::vector num_output_cores_in_link(num_links); - std::vector output_tensor_pages_in_link(num_links); - uint32_t total_output_cores = output_tensor_cores.num_cores(); - uint32_t output_cores_per_link = total_output_cores / num_links; - uint32_t remainder = total_output_cores % num_links; - uint32_t remaining_output_cores = total_output_cores; - uint32_t remaining_output_pages = input_tensor_num_pages; - - for (uint32_t link = 0; link < num_links - 1; link++) { - if (remainder == 0 || output_cores_per_link % 2 == 0) { - num_output_cores_in_link[link] = output_cores_per_link; - } else { - num_output_cores_in_link[link] = output_cores_per_link + 1; - } - remaining_output_cores -= num_output_cores_in_link[link]; - output_tensor_pages_in_link[link] = num_output_cores_in_link[link] * output_tensor_shard_num_pages; - remaining_output_pages -= output_tensor_pages_in_link[link]; - } - TT_FATAL(remaining_output_cores > 0, "remaining output cores for last link should be greater than 0"); - num_output_cores_in_link[num_links - 1] = remaining_output_cores; - output_tensor_pages_in_link[num_links - 1] = remaining_output_pages; - std::cout << "num_output_cores_in_link[last]: " << num_output_cores_in_link[num_links - 1] << std::endl; - // Kernel Runtime Args - CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready - // semaphore - auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); - auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); - - uint32_t worker_output_core_start_idx = 0; - uint32_t worker_input_core_start_idx = 0; for (uint32_t link = 0; link < num_links; link++) { CoreCoord core = sender_worker_cores[link]; - - // construct input and output core x and y - + CoreCoord drain_sync_core = device->worker_core_from_logical_core(core); uint32_t worker_num_tiles_to_read = output_tensor_pages_in_link[link]; - uint32_t input_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % input_tensor_shard_num_pages; - uint32_t output_first_core_tile_start_offset = 0; // worker_num_tiles_to_read % output_tensor_shard_num_pages; - std::vector input_tensor_cores_x; - std::vector input_tensor_cores_y; - std::vector output_tensor_cores_x; - std::vector output_tensor_cores_y; - std::vector output_coreranges_this_link; - output_coreranges_this_link.reserve(num_output_cores_in_link[link]); if (link < num_links - 1) { TT_FATAL( worker_num_tiles_to_read % input_tensor_shard_num_pages == 0, @@ -281,35 +290,26 @@ 
operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers "offset is not supported"); } - for (uint32_t i = worker_output_core_start_idx; - i < worker_output_core_start_idx + num_output_cores_in_link[link]; - i++) { - output_coreranges_this_link.push_back(CoreRange(output_cores_vec[i])); - auto this_core = device->worker_core_from_logical_core(output_cores_vec[i]); - output_tensor_cores_x.push_back(this_core.x); - output_tensor_cores_y.push_back(this_core.y); - } - worker_output_core_start_idx += num_output_cores_in_link[link]; + uint32_t input_first_core_tile_start_offset = 0; + uint32_t output_first_core_tile_start_offset = 0; - for (uint32_t i = worker_input_core_start_idx; - i < worker_input_core_start_idx + - (worker_num_tiles_to_read + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; + std::vector input_tensor_cores_x; + std::vector input_tensor_cores_y; + std::vector output_tensor_cores_x; + std::vector output_tensor_cores_y; + for (uint32_t i = input_cores_per_link * link; i < input_cores_per_link * link + num_input_cores_in_link[link]; i++) { auto this_core = device->worker_core_from_logical_core(input_cores_vec[i]); input_tensor_cores_x.push_back(this_core.x); input_tensor_cores_y.push_back(this_core.y); } - worker_input_core_start_idx += - (worker_num_tiles_to_read + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; - - tt::log_debug(tt::LogOp, "worker_num_tiles_to_read: {}", worker_num_tiles_to_read); - tt::log_debug(tt::LogOp, "input_first_core_tile_start_offset: {}", input_first_core_tile_start_offset); - tt::log_debug(tt::LogOp, "output_first_core_tile_start_offset: {}", output_first_core_tile_start_offset); - tt::log_debug(tt::LogOp, "input_tensor_cores_x: {}", input_tensor_cores_x); - tt::log_debug(tt::LogOp, "input_tensor_cores_y: {}", input_tensor_cores_y); - - // drain sync core is the first worker core - drain_sync_core = device->worker_core_from_logical_core(core); + for (uint32_t i = output_cores_per_link * link; + i < output_cores_per_link * link + num_output_cores_in_link[link]; + i++) { + auto this_core = device->worker_core_from_logical_core(output_cores_vec[i]); + output_tensor_cores_x.push_back(this_core.x); + output_tensor_cores_y.push_back(this_core.y); + } std::optional forward_fabric_connection = line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) @@ -339,22 +339,14 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); // Set writer runtime args - CoreRangeSet output_crs_this_link(std::move(output_coreranges_this_link)); - output_crs_this_link = output_crs_this_link.merge_ranges(); - auto mcast_start_core = device->worker_core_from_logical_core(output_crs_this_link.bounding_box().start_coord); - auto mcast_end_core = device->worker_core_from_logical_core(output_crs_this_link.bounding_box().end_coord); - - std::cout << "output_crs_this_link num_cores: " << output_crs_this_link.num_cores() << "for link : " << link - << std::endl; - std::cout << "output_tensor_cores_x.size(): " << output_tensor_cores_x.size() << std::endl; - std::cout << "mcast_start_core.x : " << output_crs_this_link.bounding_box().start_coord.x - << " , mcast start y : " << output_crs_this_link.bounding_box().start_coord.y << std::endl; - std::cout << "mcast_end_core.x : " << output_crs_this_link.bounding_box().end_coord.x - << " , mcast end y : " << 
output_crs_this_link.bounding_box().end_coord.y << std::endl; - - bool wait_output_semaphore = true; //(link == 0) && !enable_async_output_tensor; - bool reset_global_semaphore = true; //(link == 0) && !enable_async_output_tensor; - uint32_t out_ready_sem_wait_value = ring_size * 1; + auto mcast_start_core = + device->worker_core_from_logical_core(output_corerangeset_per_link[link].bounding_box().start_coord); + auto mcast_end_core = + device->worker_core_from_logical_core(output_corerangeset_per_link[link].bounding_box().end_coord); + + bool wait_output_semaphore = true; // (link == 0) && !enable_async_output_tensor; + bool reset_global_semaphore = true; // (link == 0) && !enable_async_output_tensor; + uint32_t out_ready_sem_wait_value = ring_size; std::vector writer_rt_args = { reduction_cb_index, // tensor_address0 output_tensor_shard_num_pages, // num_tiles_per_core @@ -410,7 +402,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers reduction_semaphore_ids[link], // reduction_semaphore_id }; tt::tt_metal::SetRuntimeArgs( - program, reduction_reader_kernel_id, output_crs_this_link, reduction_reader_rt_args); + program, reduction_reader_kernel_id, output_corerangeset_per_link[link], reduction_reader_rt_args); } auto override_runtime_arguments_callback = diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp index d874af00fb9..74189eeaeb0 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp @@ -14,17 +14,13 @@ void kernel_main() { // runtime args size_t arg_idx = 0; const uint32_t signal_semaphore_addr = get_semaphore(get_arg_val(arg_idx++)); - // DPRINT << "signal_semaphore_addr: " << signal_semaphore_addr << "\n"; volatile tt_l1_ptr uint32_t* signal_semaphore_addr_ptr = reinterpret_cast(signal_semaphore_addr); - // DPRINT << "WAIT for Reduction Signal \n"; - // DPRINT << "signal_semaphore_addr_ptr: " << *signal_semaphore_addr_ptr << "\n"; // 1. Wait for signal from All-Gather worker noc_semaphore_wait(signal_semaphore_addr_ptr, VALID); noc_semaphore_set(signal_semaphore_addr_ptr, 0); - // DPRINT << "Pushing Tiles \n"; // 2. Signal compute kernel to start processing cb_push_back(cb_id, total_num_reduction_tiles); From e0ce9505be5c913dfd66777f9e2b18fd1ad08231 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Thu, 6 Feb 2025 16:32:35 -0600 Subject: [PATCH 17/25] Rebase and fix/cleanup stuff. 
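Besides renaming the kernels into dataflow/ and compute/ subdirectories (see the renames below), this cleanup moves the global-semaphore address to writer runtime-arg slot 1 and has the program-cache override callback patch only that slot, refreshing the output via UpdateDynamicCircularBufferAddress instead of rewriting the output buffer address. A minimal, standalone sketch of that single-slot update pattern follows; the values are made up for illustration and this is not tt-metal code:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        // Writer args in the new order (illustrative): cb index, out-ready semaphore address, tiles per core, ...
        std::vector<uint32_t> writer_rt_args = {2, 0x10000, 4};

        // On a program-cache hit only the semaphore address can change, so the
        // override callback rewrites a single, known slot.
        auto on_cache_hit = [&](uint32_t new_sem_addr) { writer_rt_args[1] = new_sem_addr; };

        on_cache_hit(0x20000);
        std::cout << "out_ready_sem_bank_addr is now 0x" << std::hex << writer_rt_args[1] << "\n";
        return 0;
    }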
--- .../device/all_reduce_async_op.cpp | 6 +- .../device/all_reduce_async_op.hpp | 6 +- ..._reduce_async_program_minimal_variants.cpp | 45 ++++---- .../reduction.cpp} | 4 +- .../reduction_receiver.cpp} | 2 +- .../worker_reader.cpp} | 22 +--- .../worker_writer.cpp} | 101 ++++++------------ 7 files changed, 70 insertions(+), 116 deletions(-) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/{eltwise_binary_kernel.cpp => compute/reduction.cpp} (94%) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/{reduction_dataflow.cpp => dataflow/reduction_receiver.cpp} (94%) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/{llama_post_binary_matmul_shape_reader.cpp => dataflow/worker_reader.cpp} (68%) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/{llama_post_binary_matmul_shape_writer.cpp => dataflow/worker_writer.cpp} (65%) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp index 9bfaed5160f..d3f3b373b65 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 @@ -56,10 +56,10 @@ AllReduceAsync create_all_reduce_async_struct( enable_persistent_fabric_mode}; } -uint32_t find_scatter_dim(const ttnn::SimpleShape& input_tensor_padded_shape, size_t num_workers) { +uint32_t find_scatter_dim(const ttnn::Shape& input_tensor_padded_shape, size_t num_workers) { // iterate until we find a dimension that is divisible by num_workers TT_FATAL(input_tensor_padded_shape.size() == 4, "Expected input tensor to have 4 dimensions"); - ttnn::SimpleShape input_tensor_shape_in_tiles{ + ttnn::Shape input_tensor_shape_in_tiles{ input_tensor_padded_shape[0], input_tensor_padded_shape[1], input_tensor_padded_shape[2] / tt::constants::TILE_HEIGHT, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp index 05c89fa9b73..33fc3554b61 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 @@ -93,7 +93,7 @@ AllReduceAsync create_all_reduce_async_struct( std::optional sub_device_id, bool enable_persistent_fabric_mode); -uint32_t find_scatter_dim(const ttnn::SimpleShape& input_tensor_padded_shape, size_t num_workers); +uint32_t find_scatter_dim(const ttnn::Shape& input_tensor_padded_shape, size_t num_workers); } // namespace all_reduce_async_detail } // namespace ccl @@ -106,7 +106,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const uint32_t ring_size, const uint32_t ring_index, ccl::Topology topology, - const GlobalSemaphore semaphore, + const GlobalSemaphore& semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode); diff --git 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 2f41c6db444..7b76a74ef9a 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 /// @@ -51,7 +51,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const uint32_t ring_size, const uint32_t ring_index, ccl::Topology topology, - const GlobalSemaphore semaphore, + const GlobalSemaphore& semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode) { tt::tt_metal::Program program{}; @@ -72,7 +72,12 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers std::optional local_fabric_handle = ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( - device, forward_device, backward_device, &program, enable_persistent_fabric_mode, num_links); + device, + forward_device.value_or(nullptr), + backward_device.value_or(nullptr), + &program, + enable_persistent_fabric_mode, + num_links); // Get OP Config, topology config std::vector input_tensors = {input_tensor}; @@ -86,8 +91,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers // Get worker cores, assuming 1 worker per link uint32_t num_workers_per_link = 1; - const auto [sender_worker_core_range, sender_worker_cores] = - choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device); + const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores( + num_links, num_workers_per_link, enable_persistent_fabric_mode, device, device->get_sub_device_ids().at(0)); // Tensor Info const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); @@ -115,14 +120,14 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; uint32_t cb_base_num_pages = std::lcm(input_tensor_shard_num_pages, output_tensor_shard_num_pages); uint32_t cb_num_pages = std::lcm(num_pages_per_packet, cb_base_num_pages); - uint32_t src0_cb_index = tt::CB::c_in0; + uint32_t src0_cb_index = tt::CBIndex::c_0; tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) - const auto reserved_packet_header_CB_index = tt::CB::c_in6; + const auto reserved_packet_header_CB_index = tt::CBIndex::c_3; static constexpr auto num_packet_headers_storable = 8; static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); tt::tt_metal::CircularBufferConfig cb_reserved_packet_header_config = @@ -212,8 +217,8 @@ operation::ProgramWithCallbacks 
all_reduce_async_minimal_multi_core_with_workers }; auto reduction_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" - "reduction_dataflow.cpp", + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "reduction_receiver.cpp", output_tensor_cores, reduction_reader_kernel_config); @@ -225,8 +230,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers }; auto reduction_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" - "eltwise_binary_kernel.cpp", + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/" + "reduction.cpp", output_tensor_cores, reduction_kernel_config); std::vector reduction_kernel_rt_args = { @@ -249,8 +254,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" - "llama_post_binary_matmul_shape_reader.cpp", + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "worker_reader.cpp", sender_worker_core_range, reader_kernel_config); @@ -272,8 +277,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/" - "llama_post_binary_matmul_shape_writer.cpp", + "ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/" + "worker_writer.cpp", sender_worker_core_range, writer_kernel_config); @@ -349,13 +354,13 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t out_ready_sem_wait_value = ring_size; std::vector writer_rt_args = { reduction_cb_index, // tensor_address0 + semaphore.address(), // out_ready_sem_bank_addr (absolute address) output_tensor_shard_num_pages, // num_tiles_per_core worker_num_tiles_to_read, // num_tiles_to_read output_first_core_tile_start_offset, // first_core_tile_start_offset output_tensor_cores_x.size(), // num_cores wait_output_semaphore, // wait_output_semaphore reset_global_semaphore, // reset_global_semaphore - semaphore.address(), // out_ready_sem_bank_addr (absolute address) drain_sync_core.x, // out_ready_sem_noc0_x drain_sync_core.y, // out_ready_sem_noc0_y out_ready_sem_wait_value, // out_ready_sem_wait_value @@ -406,7 +411,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } auto override_runtime_arguments_callback = - [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores]( + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores, cb_out]( const void* operation, Program& program, const std::vector& input_tensors, @@ -415,6 +420,8 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const auto& input = input_tensors[0]; const auto& output = output_tensors[0]; + auto semaphore = static_cast(operation)->semaphore; + // update senders auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); @@ -424,7 +431,9 @@ 
operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers worker_reader_sender_runtime_args[0] = input.buffer()->address(); // writer auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; - worker_writer_sender_runtime_args[0] = output.buffer()->address(); + worker_writer_sender_runtime_args[1] = semaphore.address(); + + UpdateDynamicCircularBufferAddress(program, cb_out, *output.buffer()); } }; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp similarity index 94% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp index 85b85140ed8..a68fd52889b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/eltwise_binary_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 @@ -17,7 +17,7 @@ void MAIN { const uint32_t block_num_tiles = get_arg_val(rt_args_idx++); const uint32_t copy_first_block = num_blocks % 2 != 0; - constexpr uint32_t max_dst_tiles = 8; + constexpr uint32_t max_dst_tiles = 8; // TODO: Make general cb_wait_front(cb_in0, num_blocks * block_num_tiles); cb_reserve_back(cb_out0, block_num_tiles); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp similarity index 94% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp index 74189eeaeb0..ca63befeea9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/reduction_dataflow.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/reduction_receiver.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp similarity index 68% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp index 61fa4f7f56c..1f525b5d2ce 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp @@ -1,29 +1,15 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args -// that don't require macros to function - #include "dataflow_api.h" #include -#include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include -#include "cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" -#include "cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" - -#include "cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" - -#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" -#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" - -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp" -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" -#include "cpp/ttnn/tensor/enum_types.hpp" #include #include +using address_t = uint32_t; +using tt::tt_metal::BufferType; + /////////////////////////////////////////////////// // COMPILE TIME ARGS /////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp similarity index 65% rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp index 16939e68178..02da84d5673 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp @@ -1,29 +1,16 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -// NOTE: This should ideally be merged with `ccl_send_reader` when we are able to support compile time args -// that don't require macros to function - #include "dataflow_api.h" #include -#include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" -#include -#include "cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" -#include "cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" - -#include "cpp/ttnn/operations/ccl/common/kernels/command_processor.hpp" - -#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp" -#include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" - -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/io_descriptors.hpp" -#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" -#include "cpp/ttnn/tensor/enum_types.hpp" +#include "cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp" #include #include +using address_t = uint32_t; +using tt::tt_metal::BufferType; + /////////////////////////////////////////////////// // COMPILE TIME ARGS /////////////////////////////////////////////////// @@ -37,50 +24,6 @@ constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(5); constexpr uint32_t num_targets_forward_direction = get_compile_time_arg_val(6); constexpr uint32_t num_targets_backward_direction = get_compile_time_arg_val(7); -FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( - uint64_t noc0_dest_noc_addr, - size_t packet_header_buffer_addr, - uint32_t num_targets_forward_direction, - uint32_t num_targets_backward_direction, - FabricConnectionManager& fabric_connection, - size_t& l1_read_addr, - uint32_t payload_size_bytes) { - const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); - const size_t payload_l1_address = l1_read_addr; - - auto pkt_hdr = reinterpret_cast(packet_header_buffer_addr); -#ifdef DEBUG_PRINT_ENABLED - pkt_hdr->reserved2 = my_chip_id; -#endif - - size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); - pkt_hdr->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ - dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); - - noc_async_write(payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); - if (fabric_connection.has_forward_connection()) { - pkt_hdr->to_chip_multicast( - tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); - fabric_connection.get_forward_connection().wait_for_empty_write_slot(); - fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( - l1_read_addr, payload_size_bytes); - fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( - (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); - } - - if (fabric_connection.has_backward_connection()) { - pkt_hdr->to_chip_multicast( - tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); - fabric_connection.get_backward_connection().wait_for_empty_write_slot(); - 
fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( - l1_read_addr, payload_size_bytes); - fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( - (uint32_t)pkt_hdr, sizeof(tt::fabric::PacketHeader)); - } - - l1_read_addr += payload_size_bytes; -} - /* * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) * dispatch implementations depending on those invocation parameters. @@ -95,13 +38,13 @@ void kernel_main() { uint32_t reduction_output_cb_id = get_arg_val(arg_idx++); address_t tensor_address0 = get_write_ptr(reduction_output_cb_id); + const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); uint32_t num_tiles_per_core = get_arg_val(arg_idx++); uint32_t num_tiles_to_read = get_arg_val(arg_idx++); uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); uint32_t num_cores = get_arg_val(arg_idx++); bool wait_output_semaphore = get_arg_val(arg_idx++); bool reset_global_semaphore = get_arg_val(arg_idx++); - const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); @@ -126,8 +69,25 @@ void kernel_main() { auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); // packet header cb - cb_reserve_back(reserved_packet_header_cb_id, num_packet_headers_storable); - auto packet_header_buffer_addr = get_write_ptr(reserved_packet_header_cb_id); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_forward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_backward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_seminc = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + + // pre-populate packet headers + volatile tt::fabric::PacketHeader* pkt_hdr_forward = + reinterpret_cast(packet_header_buffer_addr_forward); + volatile tt::fabric::PacketHeader* pkt_hdr_backward = + reinterpret_cast(packet_header_buffer_addr_backward); + pkt_hdr_forward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + pkt_hdr_backward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); if (fabric_connection.is_logically_connected()) { fabric_connection.open(); @@ -155,9 +115,8 @@ void kernel_main() { write_and_advance_local_read_address_for_fabric_write( noc0_dest_noc_addr, - packet_header_buffer_addr, - num_targets_forward_direction, - num_targets_backward_direction, + pkt_hdr_forward, + pkt_hdr_backward, fabric_connection, l1_read_addr, num_tiles_to_read_this_core * tensor0_page_size); @@ -173,7 +132,7 @@ void kernel_main() { } // 2. 
mcast output ready semaphore - auto* pkt_hdr = reinterpret_cast(packet_header_buffer_addr); + auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); pkt_hdr->to_atomic_inc(); pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ out_ready_sem_bank_addr, @@ -187,7 +146,7 @@ void kernel_main() { pkt_hdr->to_chip_multicast( tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( - packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); } // Write the mcast packet (backward) if (fabric_connection.has_backward_connection()) { @@ -195,7 +154,7 @@ void kernel_main() { tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); fabric_connection.get_backward_connection().wait_for_empty_write_slot(); fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( - packet_header_buffer_addr, sizeof(tt::fabric::PacketHeader)); + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); } // increment locally uint64_t out_ready_sem_noc_addr = From ea2edb7b323327e40b5ebcbdfd48cfbd10624568 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Fri, 7 Feb 2025 07:32:25 -0600 Subject: [PATCH 18/25] Clean up pytest and enable trace. --- .../operations/ccl/test_new_all_reduce.py | 53 ++++++++++++------- .../device/kernels/dataflow/worker_writer.cpp | 2 +- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 0e58e53a664..2b179fb417c 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -5,6 +5,7 @@ import torch import pytest import math +from time import time from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc @@ -152,28 +153,31 @@ def run_op(): num_links=num_links, subdevice_id=worker_sub_device_id, ) - for d in mesh_device.get_devices(): - ttnn.synchronize_device(d) + if not trace_mode: + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) outs.append(out) return outs - # ##### Compile Model ##### - # logger.info("Compiling model") - # tt_outs = run_op() + if trace_mode: + ##### Compile Model ##### + logger.info("Compiling model") + tt_outs = run_op() - # ##### Capture Trace ##### - # logger.info("Capturing trace") + ##### Capture Trace ##### + logger.info("Capturing trace") - # trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - # tt_outs = run_op() - # ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + tt_outs = run_op() + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) - # ##### Run Trace ##### - # logger.info("Running trace") - # ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ##### Run Trace ##### + logger.info("Running trace") + ttnn.execute_trace(mesh_device, trace_id, blocking=False) - tt_outs = run_op() + else: + tt_outs = run_op() ################################## ##### Validation @@ -188,19 +192,29 @@ def run_op(): else: output_tensor_ = output_tensor[i // cluster_shape[cluster_axis]].unsqueeze(0).unsqueeze(0) - tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - logger.info(f"Checking for device 
{t.device().id()}") + tt_output_tensor = t.cpu().to_torch() + # logger.info(f"Checking for device {t.device().id()}") if input_dtype == ttnn.bfloat16: eq, output = comp_pcc(tt_output_tensor, output_tensor_) else: eq, output = comp_pcc(tt_output_tensor, output_tensor_) - logger.info(f"PCC output for {i} is: {output}") assert eq, f"{i} FAILED: {output}" + logger.info(f"PCC output for {tensor_index} is: {output}") + + for i in range(mesh_device.get_num_devices()): + assert ( + mesh_device.get_devices()[i].num_program_cache_entries() == 1 + or mesh_device.get_devices()[i].num_program_cache_entries() == num_iters + ), f"Device {i} has {mesh_device.get_devices()[i].num_program_cache_entries()} program cache entries" + finally: if enable_persistent_fabric and teardown_persistent_fabric: mesh_device.reset_sub_device_stall_group() + t1 = time() teardown_fabric_interface(mesh_device) + t2 = time() + logger.info(f"Teardown time: {t2 - t1}") @skip_for_grayskull("Requires eth connected devices to run") @@ -225,8 +239,9 @@ def run_op(): ttnn.bfloat8_b, ], ) -@pytest.mark.parametrize("num_iters", [1]) +@pytest.mark.parametrize("num_iters", [5]) @pytest.mark.parametrize("enable_async", [True]) +@pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( "device_params", [{"trace_region_size": 23887872}], @@ -249,6 +264,7 @@ def test_all_reduce( output_num_cores, num_iters, enable_async, + trace_mode, use_program_cache, function_level_defaults, ): @@ -262,4 +278,5 @@ def test_all_reduce( output_num_cores, num_iters=num_iters, enable_async=enable_async, + trace_mode=trace_mode, ) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp index 02da84d5673..bb25661536d 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp @@ -153,7 +153,7 @@ void kernel_main() { pkt_hdr->to_chip_multicast( tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); fabric_connection.get_backward_connection().wait_for_empty_write_slot(); - fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); } // increment locally From b6277d90867b09e13d3e057c55f7d7fd677eba46 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Fri, 7 Feb 2025 08:28:07 -0600 Subject: [PATCH 19/25] Adding gsem fix for multi-iter. 
--- tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 2b179fb417c..f428223cce9 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -68,7 +68,9 @@ def run_all_reduce_impl( mesh_device.set_sub_device_stall_group(sub_device_stall_group) # create global semaphore handles - ccl_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) + ccl_semaphore_handles = [ + create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(8) + ] logger.info(f"Output shape: {output_shape}") @@ -147,7 +149,7 @@ def run_op(): tt_input_tensor, cluster_axis=cluster_axis, mesh_device=mesh_device, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[i % 8], memory_config=output_mem_config, topology=ttnn.Topology.Linear, num_links=num_links, From 4520362fa3be228780bd0e89abc469fc52b3f997 Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Fri, 7 Feb 2025 18:53:56 +0000 Subject: [PATCH 20/25] #0: added api to subtract corerangesets --- tt_metal/api/tt-metalium/core_coord.hpp | 2 + tt_metal/common/core_coord.cpp | 72 +++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/tt_metal/api/tt-metalium/core_coord.hpp b/tt_metal/api/tt-metalium/core_coord.hpp index f7889fcc746..203ccfae36a 100644 --- a/tt_metal/api/tt-metalium/core_coord.hpp +++ b/tt_metal/api/tt-metalium/core_coord.hpp @@ -175,6 +175,8 @@ class CoreRangeSet { // code that uses this CoreRangeSet. 
CoreRangeSet merge_ranges() const; + CoreRangeSet subtract(const CoreRangeSet& other) const; + private: void validate_no_overlap(); diff --git a/tt_metal/common/core_coord.cpp b/tt_metal/common/core_coord.cpp index ffad3d88c7d..a9b74f075de 100644 --- a/tt_metal/common/core_coord.cpp +++ b/tt_metal/common/core_coord.cpp @@ -429,6 +429,78 @@ void CoreRangeSet::validate_no_overlap() { } } +CoreRangeSet CoreRangeSet::subtract(const CoreRangeSet& other) const { + if (other.empty()) { + return *this; + } + if (this->empty()) { + return CoreRangeSet(); + } + + if (!this->intersects(other)) { + return *this; + } + + std::vector result_ranges; + + for (const auto& current_range : this->ranges_) { + std::vector current_remaining = {current_range}; + + for (const auto& subtract_range : other.ranges_) { + std::vector new_remaining; + + for (const auto& remaining : current_remaining) { + if (!remaining.intersects(subtract_range)) { + new_remaining.push_back(remaining); + continue; + } + + auto intersection_opt = remaining.intersection(subtract_range); + if (!intersection_opt.has_value()) { + new_remaining.push_back(remaining); + continue; + } + + CoreRange intersection = intersection_opt.value(); + + if (remaining.start_coord.x < intersection.start_coord.x) { + CoreRange left{ + remaining.start_coord, CoreCoord{intersection.start_coord.x - 1, remaining.end_coord.y}}; + new_remaining.push_back(left); + } + + if (remaining.end_coord.x > intersection.end_coord.x) { + CoreRange right{ + CoreCoord{intersection.end_coord.x + 1, remaining.start_coord.y}, remaining.end_coord}; + new_remaining.push_back(right); + } + + if (remaining.start_coord.y < intersection.start_coord.y) { + CoreRange bottom{ + CoreCoord{ + std::max(remaining.start_coord.x, intersection.start_coord.x), remaining.start_coord.y}, + CoreCoord{ + std::min(remaining.end_coord.x, intersection.end_coord.x), intersection.start_coord.y - 1}}; + new_remaining.push_back(bottom); + } + + if (remaining.end_coord.y > intersection.end_coord.y) { + CoreRange top{ + CoreCoord{ + std::max(remaining.start_coord.x, intersection.start_coord.x), + intersection.end_coord.y + 1}, + CoreCoord{std::min(remaining.end_coord.x, intersection.end_coord.x), remaining.end_coord.y}}; + new_remaining.push_back(top); + } + } + current_remaining = new_remaining; + } + result_ranges.insert(result_ranges.end(), current_remaining.begin(), current_remaining.end()); + } + + return CoreRangeSet(std::move(result_ranges)); +} + bool operator==(const CoreRangeSet& a, const CoreRangeSet& b) { if (a.ranges().size() == b.ranges().size()) { auto range_a = a.ranges(); From 203e2e8a913cbc10dd470f24232a0ccb5316c249 Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Fri, 7 Feb 2025 14:26:43 -0600 Subject: [PATCH 21/25] #0: updated choose_worker_cores function to omit reserved_cores --- .../ccl/all_gather_async/device/all_gather_async_op.hpp | 7 ++++++- .../all_gather_async/device/all_gather_async_program.cpp | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp index d8193771f62..853c259e4f6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -110,7 +110,12 @@ AllGatherAsync create_all_gather_async_struct( // All Gather Variants std::tuple> 
choose_worker_cores( - size_t num_links, size_t num_workers_per_link, bool persistent_fabric_mode, IDevice* device, const std::optional& sub_device_id); + size_t num_links, + size_t num_workers_per_link, + bool persistent_fabric_mode, + IDevice* device, + const std::optional& sub_device_id, + const std::optional& reserved_core_range = std::nullopt); operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( const Tensor& input_tensor, std::optional forward_device, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp index dbcc0d5848d..8a13126f038 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program.cpp @@ -77,14 +77,18 @@ std::tuple> choose_worker_cores( size_t num_workers_per_link, bool persistent_fabric_mode, IDevice* device, - const std::optional& sub_device_id) { + const std::optional& sub_device_id, + const std::optional& reserved_core_range) { std::tuple> result; CoreRangeSet sender_worker_core_range; if (persistent_fabric_mode) { const size_t num_workers_preferred = num_workers_per_link * num_links; - const auto available_cores = device->worker_cores( + auto available_cores = device->worker_cores( HalProgrammableCoreType::TENSIX, sub_device_id.has_value() ? *sub_device_id : device->get_sub_device_ids().at(0)); + if (reserved_core_range.has_value()) { + available_cores = available_cores.subtract(*reserved_core_range); + } if (available_cores.num_cores() < num_workers_preferred) { log_warning( tt::LogOp, From 694ef4d096aaa5a6cc4870f7dfe5a92f6df122ae Mon Sep 17 00:00:00 2001 From: kpaigwar Date: Fri, 7 Feb 2025 14:28:29 -0600 Subject: [PATCH 22/25] #0: fix placement of link worker cores --- .../unit_tests/operations/ccl/test_new_all_reduce.py | 5 ++--- .../all_reduce_async_program_minimal_variants.cpp | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index f428223cce9..a95c4c7694f 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -80,7 +80,6 @@ def run_all_reduce_impl( ################################## ##### FF2 Case ##### - core_offset = num_links M, N = output_shape[2:] N_per_shard = round_up(math.ceil(N / input_num_cores), ttnn.TILE_SIZE) output_N_per_shard = round_up(math.ceil(N / output_num_cores), ttnn.TILE_SIZE) @@ -93,7 +92,7 @@ def run_all_reduce_impl( ttnn.CoreCoord(x, y), ttnn.CoreCoord(x, y), ) - for x, y in CORE_RANGE[core_offset : core_offset + input_num_cores] + for x, y in CORE_RANGE[:input_num_cores] ] ) input_mem_config = ttnn.MemoryConfig( @@ -111,7 +110,7 @@ def run_all_reduce_impl( ttnn.CoreCoord(x, y), ttnn.CoreCoord(x, y), ) - for x, y in CORE_RANGE[core_offset : core_offset + output_num_cores] + for x, y in CORE_RANGE[:output_num_cores] ] ) output_mem_config = ttnn.MemoryConfig( diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index 7b76a74ef9a..ed2415d515f 100644 --- 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -88,12 +88,6 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); const size_t num_targets_backward = line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); - - // Get worker cores, assuming 1 worker per link - uint32_t num_workers_per_link = 1; - const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores( - num_links, num_workers_per_link, enable_persistent_fabric_mode, device, device->get_sub_device_ids().at(0)); - // Tensor Info const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; @@ -106,6 +100,12 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers const auto output_tensor_shard_num_pages = output_tensor_shard_shape[0] * output_tensor_shard_shape[1] / TILE_HW; const auto num_output_cores = output_tensor_cores.num_cores(); + // Get worker cores, assuming 1 worker per link + std::optional<CoreRangeSet> reserved_cores = output_tensor_cores; + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = choose_worker_cores( + num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id, reserved_cores); + tt::log_debug(tt::LogOp, "input_tensor_num_pages: {}", input_tensor_num_pages); tt::log_debug(tt::LogOp, "input_tensor_cores: {}", input_tensor_cores); tt::log_debug(tt::LogOp, "input_tensor_shard_shape: {}", input_tensor_shard_shape); From a13189710189f0ba4c8f342cd1789c035742d8a9 Mon Sep 17 00:00:00 2001 From: avoraTT Date: Fri, 7 Feb 2025 18:04:06 -0600 Subject: [PATCH 23/25] Add support for input shard not divisible by output shard. 
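When the per-link output page count is not a multiple of the input shard size, a link can stop partway through an input core, and the next link must start on that same core at a non-zero tile offset. The sketch below walks through the split arithmetic added in this patch as a standalone program; the numbers are illustrative, loosely following the 4-link FF2/DO shape in the test (24 input cores with 3-tile shards feeding 4 links of 16 output pages each), and it is not a drop-in for the program factory:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    static uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t num_input_cores = 24;
        const uint32_t input_shard_pages = 3;                           // tiles per input shard
        const std::vector<uint32_t> pages_per_link = {16, 16, 16, 16};  // per-link output pages

        uint32_t start_core = 0;
        uint32_t overflow = 0;  // unconsumed pages left on the previous link's last core
        for (uint32_t link = 0; link < pages_per_link.size(); link++) {
            const uint32_t offset = (input_shard_pages - overflow) % input_shard_pages;
            const uint32_t end_core = std::min(
                start_core + div_up(pages_per_link[link] + offset, input_shard_pages), num_input_cores - 1);
            const uint32_t allocated = (end_core - start_core) * input_shard_pages - offset;
            overflow = allocated - pages_per_link[link];
            std::cout << "link " << link << ": input cores [" << start_core << ", " << end_core
                      << "), first-core tile offset " << offset << "\n";
            // If this link over-allocated, the next link re-uses the last core and starts mid-shard.
            start_core = (overflow > 0) ? end_core - 1 : end_core;
        }
        return 0;
    }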
--- .../operations/ccl/test_new_all_reduce.py | 13 ++++- ..._reduce_async_program_minimal_variants.cpp | 55 +++++++++++++------ .../device/kernels/dataflow/worker_reader.cpp | 2 +- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index a95c4c7694f..405b9106f3a 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -174,8 +174,16 @@ def run_op(): ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) ##### Run Trace ##### - logger.info("Running trace") + logger.info("Starting Trace perf test...") + time_start = time() ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ttnn.release_trace(mesh_device, trace_id) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + time_end = time() + logger.info(f"Time taken: {time_end - time_start} s") + logger.info(f"Time per iter: {(time_end - time_start) / num_iters} s") + logger.info(f"Time per iter: {(time_end - time_start) / num_iters * 1e6} us") else: tt_outs = run_op() @@ -222,12 +230,13 @@ def run_op(): @pytest.mark.parametrize( "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores", [ + ([1, 1, 32, 2048], 0, 4, 24, 16), # FF2/DO all reduce ([1, 1, 32, 1280], 1, 3, 24, 40), # QKV all reduce ([1, 1, 32, 3584], 1, 3, 24, 24), # FF1 all reduce ([1, 1, 32, 2048], 0, 3, 24, 16), # FF2/DO all reduce ([1, 1, 32, 1280], 1, 2, 24, 40), # QKV all reduce ([1, 1, 32, 3584], 1, 2, 24, 24), # FF1 all reduce - # ([1, 1, 32, 2048], 0, 2, 24, 16), # FF2/DO all reduce # Not supported + ([1, 1, 32, 2048], 0, 2, 24, 16), # FF2/DO all reduce ([1, 1, 32, 1280], 1, 1, 24, 40), # QKV all reduce ([1, 1, 32, 3584], 1, 1, 24, 24), # FF1 all reduce ([1, 1, 32, 2048], 0, 1, 24, 16), # FF2/DO all reduce diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp index ed2415d515f..adc1f988021 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_program_minimal_variants.cpp @@ -119,7 +119,7 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; uint32_t cb_base_num_pages = std::lcm(input_tensor_shard_num_pages, output_tensor_shard_num_pages); - uint32_t cb_num_pages = std::lcm(num_pages_per_packet, cb_base_num_pages); + uint32_t cb_num_pages = input_tensor_num_pages; uint32_t src0_cb_index = tt::CBIndex::c_0; tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); tt::tt_metal::CircularBufferConfig cb_src0_config = @@ -169,14 +169,38 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers } // Create input tensor splits - std::vector num_input_cores_in_link(num_links, 0); - uint32_t input_cores_per_link = - tt::div_up(output_tensor_pages_in_link[0], input_tensor_shard_num_pages); // TODO: Add validation - uint32_t num_assigned_input_cores = 0; + std::vector> input_cores_idx_per_link(num_links, {0, 0}); + std::vector 
input_tensor_tile_offset_per_link(num_links, 0); + uint32_t start_core_idx = 0; + uint32_t num_pages_overflow = 0; for (uint32_t link = 0; link < num_links; link++) { - uint32_t num_cores_this_link = std::min(input_cores_per_link, num_input_cores - num_assigned_input_cores); - num_input_cores_in_link[link] = num_cores_this_link; - num_assigned_input_cores += num_cores_this_link; + uint32_t num_pages_this_link = output_tensor_pages_in_link[link]; + + // Get offset based on previous overflow + uint32_t input_tensor_tile_offset = + (input_tensor_shard_num_pages - num_pages_overflow) % input_tensor_shard_num_pages; + input_tensor_tile_offset_per_link[link] = input_tensor_tile_offset; + + uint32_t end_core_idx = std::min( + start_core_idx + tt::div_up(num_pages_this_link + input_tensor_tile_offset, input_tensor_shard_num_pages), + num_input_cores - 1); + + // Num pages allocated based on number of input cores selected for this link + uint32_t num_pages_allocated = + (end_core_idx - start_core_idx) * input_tensor_shard_num_pages - input_tensor_tile_offset; + + // Update overflow + num_pages_overflow = num_pages_allocated - num_pages_this_link; + + // Store core indices + input_cores_idx_per_link[link] = {start_core_idx, end_core_idx}; + + // Set start index based on overflow + if (num_pages_overflow > 0) { + start_core_idx = end_core_idx - 1; + } else { + start_core_idx = end_core_idx; + } } // Create reduction semaphore vector for each link @@ -288,22 +312,14 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers CoreCoord drain_sync_core = device->worker_core_from_logical_core(core); uint32_t worker_num_tiles_to_read = output_tensor_pages_in_link[link]; - if (link < num_links - 1) { - TT_FATAL( - worker_num_tiles_to_read % input_tensor_shard_num_pages == 0, - "worker_num_tiles_to_read must be divisible by input_tensor_shard_num_pages, currently shard tile " - "offset is not supported"); - } - - uint32_t input_first_core_tile_start_offset = 0; + uint32_t input_first_core_tile_start_offset = input_tensor_tile_offset_per_link[link]; uint32_t output_first_core_tile_start_offset = 0; std::vector input_tensor_cores_x; std::vector input_tensor_cores_y; std::vector output_tensor_cores_x; std::vector output_tensor_cores_y; - for (uint32_t i = input_cores_per_link * link; i < input_cores_per_link * link + num_input_cores_in_link[link]; - i++) { + for (uint32_t i = input_cores_idx_per_link[link].first; i < input_cores_idx_per_link[link].second; i++) { auto this_core = device->worker_core_from_logical_core(input_cores_vec[i]); input_tensor_cores_x.push_back(this_core.x); input_tensor_cores_y.push_back(this_core.y); @@ -408,6 +424,9 @@ operation::ProgramWithCallbacks all_reduce_async_minimal_multi_core_with_workers }; tt::tt_metal::SetRuntimeArgs( program, reduction_reader_kernel_id, output_corerangeset_per_link[link], reduction_reader_rt_args); + + input_first_core_tile_start_offset = + (worker_num_tiles_to_read % input_tensor_shard_num_pages) + input_first_core_tile_start_offset; } auto override_runtime_arguments_callback = diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp index 1f525b5d2ce..104a020d83e 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp +++ 
b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_reader.cpp @@ -42,7 +42,7 @@ void kernel_main() { // interleaved addrgen uint32_t tiles_read = 0; - uint32_t shard_tile_id = 0; // first_core_tile_start_offset; + uint32_t shard_tile_id = first_core_tile_start_offset; uint32_t core_id = 0; while (tiles_read < num_tiles_to_read) { // DPRINT << "tiles_read: " << tiles_read << "\n"; From bcf48a620040e2f49013b76212880072edff8ee5 Mon Sep 17 00:00:00 2001 From: Jack Cai Date: Mon, 10 Feb 2025 22:23:29 +0000 Subject: [PATCH 24/25] added all reduce into llama ccl perf test and added proper measurement of e2e trace perf --- .../ccl/test_all_gather_TG_post_commit.py | 90 ++++--- .../operations/ccl/test_ccl_async_TG_llama.py | 255 +++++------------- .../operations/ccl/test_ccl_common.py | 1 - .../operations/ccl/test_new_all_reduce.py | 89 ++++-- .../device/kernels/compute/reduction.cpp | 1 - .../device/kernels/dataflow/worker_writer.cpp | 11 +- 6 files changed, 195 insertions(+), 252 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index 7f37600028a..2dcd7473d0c 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -64,6 +64,7 @@ def run_with_trace( n_worker=None, n_buffer=None, num_iter=20, + warmup_iters=0, use_all_gather_async=False, profiler=BenchmarkProfiler(), ): @@ -99,49 +100,68 @@ def run_with_trace( # Capture trace logger.info("Capturing trace") - trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - for i in range(num_iter): - if use_all_gather_async: - logger.info("Running all-gather async") - tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles[i] - if type(ccl_semaphore_handles) == list - else ccl_semaphore_handles, - num_links=num_links, - memory_config=output_mem_config, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - ) - else: - tt_out_tensor = ttnn.all_gather( - input_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - memory_config=output_mem_config, - topology=all_gather_topology, - ) - ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) - for d in mesh_device.get_devices(): - ttnn.synchronize_device(d) + + def capture_trace(n_iters): + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for i in range(n_iters): + if use_all_gather_async: + logger.info("Running all-gather async") + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + multi_device_global_semaphore=ccl_semaphore_handles[i] + if type(ccl_semaphore_handles) == list + else ccl_semaphore_handles, + num_links=num_links, + memory_config=output_mem_config, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) + else: + tt_out_tensor = ttnn.all_gather( + input_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + ) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + for d in mesh_device.get_devices(): + 
ttnn.synchronize_device(d) + return trace_id + + if warmup_iters > 0: + trace_id_warmup = capture_trace(warmup_iters) + trace_id = capture_trace(num_iter) # Run the op logger.info("Starting Trace perf test...") + profiler.start("all-gather-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + profiler.end("all-gather-async-trace-warmup") + profiler.start("all-gather-async-trace") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) for d in mesh_device.get_devices(): ttnn.synchronize_device(d) profiler.end("all-gather-async-trace") - logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us") + time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration( + "all-gather-async-trace-warmup" + ) + effective_iter = num_iter - warmup_iters + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {time_taken / effective_iter} s") + logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us") return tt_out_tensor @@ -163,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( output_shard_spec: ttnn.ShardSpec = None, num_all_gather_instances: int = 1, num_iters: int = 1, + warmup_iters: int = 0, cluster_axis: int = 0, tile=(32, 32), trace_mode=False, @@ -277,6 +298,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( enable_persistent_fabric=enable_persistent_fabric, all_gather_topology=ttnn.Topology.Linear, num_iter=num_iters, + warmup_iters=warmup_iters, use_all_gather_async=use_all_gather_async, profiler=profiler, ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index fe967467e14..0d6795a9551 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -20,8 +20,8 @@ from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import ( run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows, ) -from tests.ttnn.unit_tests.operations.ccl.test_all_reduce_async import ( - run_all_reduce_with_mesh_tensor_along_row, +from tests.ttnn.unit_tests.operations.ccl.test_new_all_reduce import ( + run_all_reduce_impl, ) from models.perf.benchmarking_utils import BenchmarkProfiler @@ -89,9 +89,9 @@ def get_core_range_set(output_core_grid): ], ) @pytest.mark.parametrize( - "num_iters", + "num_iters, warmup_iters", [ - 5000, + (500, 100), ], ) @pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @@ -131,7 +131,7 @@ def get_core_range_set(output_core_grid): @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) @pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True) -def test_line_all_gather_sharded_on_TG_rows_llama( +def test_all_gather_llama( mesh_device, num_devices, output_shape, @@ -150,6 +150,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( enable_async, replication_factor, num_iters, + warmup_iters, perf_target_us, ): if 
len(mesh_device.get_devices()) != 32: @@ -185,6 +186,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async=enable_async, num_iters=num_iters, + warmup_iters=warmup_iters, input_shard_spec=input_shard_spec, output_shard_spec=output_shard_spec, num_all_gather_instances=replication_factor, @@ -197,7 +199,13 @@ def test_line_all_gather_sharded_on_TG_rows_llama( teardown_persistent_fabric=True, ) - latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6 + time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration( + "all-gather-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") if perf_target_us is not None: assert ( latency_us < perf_target_us @@ -206,222 +214,81 @@ def test_line_all_gather_sharded_on_TG_rows_llama( @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "num_devices, num_links", + "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores, perf_target_us", [ - (4, 2), + ([1, 1, 32, 2048], 0, 4, 24, 16, 84), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 3, 24, 40, 60), # QKV all reduce + ([1, 1, 32, 3584], 1, 3, 24, 24, 69), # FF1 all reduce ], ) -@pytest.mark.parametrize( - "tensor_mem_layout, per_chip_input_shape, dim, input_shard_shape,shard_grid,layout", - ( - ( # ReduceScatter After FF1/3 (~100 us) - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - (1, 1, 32, 3840), - 3, - (32, 160), - ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}), - ttnn.TILE_LAYOUT, - ), - ), -) -@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, - # ttnn.bfloat8_b, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize( - "buffer_type", + "num_iters, warmup_iters", [ - ttnn.BufferType.L1, + (1000, 100), ], ) @pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [8]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_reduce_scatter_sharded_on_TG_rows_llama( - mesh_device, - num_devices, - per_chip_input_shape, - tensor_mem_layout, - input_shard_shape, - shard_grid, - shard_grid_orientation, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - input_shard_spec = ttnn.ShardSpec( - shard_grid, - input_shard_shape, - shard_grid_orientation, - ) - - logger.warning("sharding not used due to issue #16699") - - run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( - mesh_device, - num_devices, - per_chip_input_shape, - ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - # input_shard_spec=input_shard_spec, - num_iters=num_iters, - num_reduce_scatter_instances=replication_factor, - cluster_axis=1, - use_reduce_scatter_async=True, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, - ) - - -# Enumerate the post-commit cases explicitly 
-@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (4, 1, [1, 1, 32, 1280], ttnn.TILE_LAYOUT), # AllReduce after QKV (~110 us) - ], -) -@pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - ], + "device_params", + [{"trace_region_size": 23887872}], + indirect=True, ) @pytest.mark.parametrize( - "buffer_type", + "mesh_device", [ - ttnn.BufferType.L1, + (8, 4), ], + indirect=True, ) -@pytest.mark.parametrize("replication_factor", [8]) # 1, 8]) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_rows_llama( +def test_all_reduce_tg_llama( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, + num_links, + input_num_cores, + output_num_cores, + num_iters, + warmup_iters, + perf_target_us, + enable_async, + trace_mode, use_program_cache, function_level_defaults, - enable_async, - replication_factor, - num_iters=10, ): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") + profiler = BenchmarkProfiler() - run_all_reduce_with_mesh_tensor_along_row( + run_all_reduce_impl( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, + num_links, + input_num_cores, + output_num_cores, num_iters=num_iters, - num_all_reduce_instances=replication_factor, - cluster_axis=1, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, + warmup_iters=warmup_iters, + enable_async=enable_async, + trace_mode=trace_mode, + validate_all=False, + profiler=profiler, ) - -@skip_for_grayskull("Requires eth connected devices to run") -@pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (8, 1, [1, 1, 32, 2048], ttnn.TILE_LAYOUT), # AllReduce after DO and AllReduce after FF2 (~240 us) - # multi-links fail https://github.com/tenstorrent/tt-metal/issues/16699 - ], -) -@pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - ], -) -@pytest.mark.parametrize( - "buffer_type", - [ - ttnn.BufferType.L1, - ], -) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [4]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_cols_llama( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") - - run_all_reduce_with_mesh_tensor_along_row( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - num_iters=num_iters, - num_all_reduce_instances=replication_factor, - 
cluster_axis=0, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") + if perf_target_us is not None: + assert ( + latency_us < perf_target_us + ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us" diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py index 65fa2a49b73..501f04ff30e 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py @@ -33,7 +33,6 @@ def teardown_fabric_interface(mesh_device): def create_global_semaphore_with_same_address(mesh_device, cores, initial_value): semaphore_handles = ttnn.create_global_semaphore_with_same_address(mesh_device, cores, initial_value) addrs = ttnn.get_global_semaphore_address(semaphore_handles) - logger.debug(f"from remote semaphore handle addresses: {addrs}") # assert all addresses are the same assert len(set(addrs)) == 1 return semaphore_handles diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 405b9106f3a..4182ef97dc6 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -20,6 +20,7 @@ num_cores_to_rectangle_grid, round_up, ) +from models.perf.benchmarking_utils import BenchmarkProfiler def run_all_reduce_impl( @@ -31,8 +32,11 @@ def run_all_reduce_impl( input_num_cores, output_num_cores, num_iters=1, + warmup_iters=0, enable_async=False, trace_mode=False, + validate_all=True, + profiler=BenchmarkProfiler(), ): cluster_shape = (8, 4) @@ -141,9 +145,9 @@ def run_all_reduce_impl( ##### Run the op ################################## - def run_op(): + def run_op(n_iters, store_all_results=True): outs = [] - for i in range(num_iters): + for i in range(n_iters): out = ttnn.experimental.all_reduce_async( tt_input_tensor, cluster_axis=cluster_axis, @@ -157,43 +161,65 @@ def run_op(): if not trace_mode: for d in mesh_device.get_devices(): ttnn.synchronize_device(d) - outs.append(out) + if store_all_results: + outs.append(out) - return outs + if store_all_results: + return outs + else: + return [out] if trace_mode: ##### Compile Model ##### logger.info("Compiling model") - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ##### Capture Trace ##### logger.info("Capturing trace") + if warmup_iters > 0: + trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0) + tt_outs = run_op(warmup_iters, store_all_results=validate_all) + ttnn.end_trace_capture(mesh_device, trace_id_warmup, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) ##### Run Trace ##### logger.info("Starting Trace perf test...") - time_start = time() + profiler.start("all-reduce-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, 
blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + profiler.end("all-reduce-async-trace-warmup") + + profiler.start("all-reduce-async-trace") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) for d in mesh_device.get_devices(): ttnn.synchronize_device(d) - time_end = time() - logger.info(f"Time taken: {time_end - time_start} s") - logger.info(f"Time per iter: {(time_end - time_start) / num_iters} s") - logger.info(f"Time per iter: {(time_end - time_start) / num_iters * 1e6} us") + profiler.end("all-reduce-async-trace") + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {time_taken / effective_iter} s") + logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us") else: - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ################################## ##### Validation ################################## - for tensor_index in range(len(tt_outs)): - tt_out_tensor = tt_outs[tensor_index] - output_tensor = output_tensor_goldens_list[tensor_index] + def validate(tt_out_tensor, output_tensor): for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): # get_device_tensors returns row major, so we need to select the correct golden tensor if cluster_axis == 0: @@ -209,7 +235,17 @@ def run_op(): else: eq, output = comp_pcc(tt_output_tensor, output_tensor_) assert eq, f"{i} FAILED: {output}" - logger.info(f"PCC output for {tensor_index} is: {output}") + logger.info(f"PCC output is: {output}") + + if validate_all: + for tensor_index in range(len(tt_outs)): + tt_out_tensor = tt_outs[tensor_index] + output_tensor = output_tensor_goldens_list[tensor_index] + validate(tt_out_tensor, output_tensor) + else: + tt_out_tensor = tt_outs[-1] + output_tensor = output_tensor_goldens_list[-1] + validate(tt_out_tensor, output_tensor) for i in range(mesh_device.get_num_devices()): assert ( @@ -249,7 +285,12 @@ def run_op(): ttnn.bfloat8_b, ], ) -@pytest.mark.parametrize("num_iters", [5]) +@pytest.mark.parametrize( + "num_iters, warmup_iters", + [ + (1000, 100), + ], +) @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( @@ -273,11 +314,14 @@ def test_all_reduce( input_num_cores, output_num_cores, num_iters, + warmup_iters, enable_async, trace_mode, use_program_cache, function_level_defaults, ): + profiler = BenchmarkProfiler() + run_all_reduce_impl( mesh_device, output_shape, @@ -287,6 +331,17 @@ def test_all_reduce( input_num_cores, output_num_cores, num_iters=num_iters, + warmup_iters=warmup_iters, enable_async=enable_async, trace_mode=trace_mode, + validate_all=False, + profiler=profiler, + ) + + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp index a68fd52889b..58136a8232f 100644 --- 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp @@ -4,7 +4,6 @@ #include #include "compute_kernel_api/eltwise_binary.h" -#include "compute_kernel_api/tile_move_copy.h" namespace NAMESPACE { void MAIN { diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp index bb25661536d..f657ae6f065 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp @@ -4,6 +4,8 @@ #include "dataflow_api.h" #include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" #include "cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp" #include #include @@ -133,13 +135,12 @@ void kernel_main() { // 2. mcast output ready semaphore auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); - pkt_hdr->to_atomic_inc(); + uint64_t out_ready_sem_noc_addr_in_pkt = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr, 0); pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ - out_ready_sem_bank_addr, + out_ready_sem_noc_addr_in_pkt, static_cast(1), // increment 1 - 32, - static_cast(out_ready_sem_noc0_x), - static_cast(out_ready_sem_noc0_y)}); + 32}); // Write the mcast packet (forward) if (fabric_connection.has_forward_connection()) { fabric_connection.get_forward_connection().wait_for_empty_write_slot(); From b779894d2392e202d1acb86a2e3820b145c1f0aa Mon Sep 17 00:00:00 2001 From: Johanna Rock Date: Thu, 6 Feb 2025 13:05:03 +0000 Subject: [PATCH 25/25] Extend llama sharded all gather for LN --- .../operations/ccl/test_ccl_async_TG_llama.py | 20 +++++++++- .../device/all_gather_async_op.cpp | 38 +++++++++++++------ .../device/all_gather_async_op.hpp | 4 +- ..._gather_async_program_minimal_variants.cpp | 24 ++++++------ ...er.cpp => llama_shapes_sharded_reader.cpp} | 0 ...er.cpp => llama_shapes_sharded_writer.cpp} | 0 6 files changed, 61 insertions(+), 25 deletions(-) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/{llama_post_binary_matmul_shape_reader.cpp => llama_shapes_sharded_reader.cpp} (100%) rename ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/{llama_post_binary_matmul_shape_writer.cpp => llama_shapes_sharded_writer.cpp} (100%) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index 0d6795a9551..5c726693543 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -74,6 +74,13 @@ def get_core_range_set(output_core_grid): return output_core_range_set +CORE_RANGE_SET_1x1 = ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0)), + } +) + + # Enumerate the post-commit cases explicitly @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( @@ -125,13 +132,24 @@ def get_core_range_set(output_core_grid): 
ttnn.TILE_LAYOUT, 25, ), + ( # AllGather for layernorm + ttnn.TensorMemoryLayout.WIDTH_SHARDED, + (1, 1, 32, 128), + 3, + (32, 32), + CORE_RANGE_SET_1x1, + (32, 128), + CORE_RANGE_SET_1x1, + ttnn.TILE_LAYOUT, + 13, + ), ), ) @pytest.mark.parametrize("replication_factor", [8]) @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) @pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True) -def test_all_gather_llama( +def test_all_gather_tg_llama( mesh_device, num_devices, output_shape, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index f295d317f64..aca9f3e5def 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -145,7 +145,7 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) log_trace(tt::LogOp, "[select_version] output_is_sharded: {}", output_is_sharded); if (input_is_sharded && output_is_sharded) { - // Check for first llama post binary matmul case + // Check for llama post binary mult+silu case if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && input_tensor_shape[3] == 960 && input_tensor_memory_config.buffer_type == BufferType::L1 && output_mem_config.buffer_type == BufferType::L1 && @@ -156,10 +156,13 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) output_mem_config.shard_spec->shape[0] == 32 && output_mem_config.shard_spec->shape[1] == 160 && input_shard_num_cores == 30 && output_shard_num_cores == 24) { - return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + log_trace( + tt::LogOp, + "Matching conditions for Llama post binary mult+silu, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; } - // Check for second llama post binary matmul case + // Check for llama post SDPA case if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 8 && input_tensor_shape[2] == 32 && input_tensor_shape[3] == 128 && input_tensor_memory_config.buffer_type == BufferType::L1 && output_mem_config.buffer_type == BufferType::L1 && @@ -170,11 +173,26 @@ AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) output_mem_config.shard_spec->shape[0] == 32 && output_mem_config.shard_spec->shape[1] == 128 && input_shard_num_cores == 8 && output_shard_num_cores == 32) { - log_trace(tt::LogOp, "All conditions matched for LLAMA_POST_BINARY_MATMUL case"); - return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + log_trace(tt::LogOp, "Matching conditions for Llama post SDPA, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; + } + + // Check for llama rms norm case + if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && + input_tensor_shape[3] == 32 && input_tensor_memory_config.buffer_type == BufferType::L1 && + output_mem_config.buffer_type == BufferType::L1 && + input_tensor_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + input_tensor_memory_config.shard_spec->shape[0] == 32 && + input_tensor_memory_config.shard_spec->shape[1] 
== 32 && output_mem_config.shard_spec->shape[0] == 32 && + output_mem_config.shard_spec->shape[1] == 128 && input_shard_num_cores == 1 && + output_shard_num_cores == 1) { + log_trace( + tt::LogOp, "Matching conditions for Llama rms norm case, using LLAMA_MINIMAL_SHARDED implementation"); + return AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED; } } - log_trace(tt::LogOp, "All conditions matched for generic case"); + log_trace(tt::LogOp, "Using generic implementation"); return AllGatherAsyncVersion::GENERIC; } @@ -206,11 +224,9 @@ operation::ProgramWithCallbacks AllGatherAsync::create_program( this->sub_device_id, this->enable_persistent_fabric_mode); - case AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL: - log_trace( - tt::LogOp, - "Detected all gather specialized shape. all_gather_async_llama_post_binary_matmul is called"); - return all_gather_async_llama_post_binary_matmul( + case AllGatherAsyncVersion::LLAMA_MINIMAL_SHARDED: + log_trace(tt::LogOp, "Detected all gather specialized shape. all_gather_async_llama_sharded is called"); + return all_gather_async_llama_sharded( input_tensors[0], this->forward_device, this->backward_device, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp index 853c259e4f6..49b272f07c8 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -28,7 +28,7 @@ using ccl::EriscDatamoverBuilder; enum class AllGatherAsyncVersion { GENERIC = 0, MINIMAL_INTERLEAVED_32 = 1, - LLAMA_POST_BINARY_MATMUL = 2, + LLAMA_MINIMAL_SHARDED = 2, }; struct AllGatherAsync { @@ -142,7 +142,7 @@ operation::ProgramWithCallbacks all_gather_async_minimal_interleaved_dim3_1_1_32 const GlobalSemaphore& semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode); -operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( +operation::ProgramWithCallbacks all_gather_async_llama_sharded( const Tensor& input_tensor, std::optional forward_device, std::optional backward_device, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp index ba8edc57bf6..49d15fce163 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp @@ -275,7 +275,7 @@ operation::ProgramWithCallbacks all_gather_async_minimal_interleaved_dim3_1_1_32 return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( +operation::ProgramWithCallbacks all_gather_async_llama_sharded( const Tensor& input_tensor, std::optional forward_device, std::optional backward_device, @@ -291,8 +291,7 @@ operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( tt::tt_metal::Program program{}; const bool enable_async_output_tensor = false; TT_FATAL( - enable_persistent_fabric_mode, - "only persistent fabric mode is supported for all_gather_async_llama_post_binary_matmul"); + enable_persistent_fabric_mode, "only persistent 
fabric mode is supported for all_gather_async_llama_sharded");
 
     IDevice* device = input_tensor.device();
     bool is_first_chip = ring_index == 0;
@@ -385,7 +384,7 @@ operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul(
     auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/"
-        "llama_post_binary_matmul_shape_reader.cpp",
+        "llama_shapes_sharded_reader.cpp",
         sender_worker_core_range,
         reader_kernel_config);
 
@@ -408,7 +407,7 @@ operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul(
     auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/"
-        "llama_post_binary_matmul_shape_writer.cpp",
+        "llama_shapes_sharded_writer.cpp",
         sender_worker_core_range,
         writer_kernel_config);
 
@@ -417,14 +416,17 @@ operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul(
     // semaphore
     auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true);
     auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true);
-    auto cores_per_device = output_cores_vec.size() / ring_size;
+    auto cores_per_device = (output_cores_vec.size() + ring_size - 1) / ring_size;
+    uint32_t start_core_index_for_device = output_cores_vec.size() / (float)ring_size * ring_index;
+    uint32_t end_core_index_for_device = start_core_index_for_device + cores_per_device;
     TT_FATAL(
-        output_cores_vec.size() % ring_size == 0,
-        "output sharded cores must be divisible by num_links for this work distribution scheme");
+        output_cores_vec.size() % ring_size == 0 || output_cores_vec.size() == 1,
+        "output sharded cores ( {} ) must be divisible by ring_size ( {} ) or be 1 for this work distribution scheme",
+        output_cores_vec.size(),
+        ring_size);
     auto output_cores_this_device = std::vector(
-        output_cores_vec.begin() + ring_index * cores_per_device,
-        output_cores_vec.begin() + (ring_index + 1) * cores_per_device);
-
+        output_cores_vec.begin() + start_core_index_for_device, output_cores_vec.begin() + end_core_index_for_device);
+    log_trace(tt::LogOp, "output_cores_this_device: {}", output_cores_this_device);
 
     for (uint32_t link = 0; link < num_links; link++) {
         CoreCoord core = sender_worker_cores[link];
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_reader.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp
rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_reader.cpp
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_writer.cpp
similarity index 100%
rename from ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp
rename to ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_shapes_sharded_writer.cpp
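
Reference sketch for the trace-perf reporting introduced in PATCH 24/25: per-iteration latency is derived from two captured traces, a warmup trace of warmup_iters iterations and a main trace of num_iters iterations, with the warmup duration subtracted before averaging over the remaining iterations. The snippet below only illustrates that arithmetic under stated assumptions; SimpleProfiler is a hypothetical stand-in for the start/end/get_duration interface of BenchmarkProfiler used in the tests, and the sleep calls are placeholders for executing the traces.

import time


class SimpleProfiler:
    # Minimal stand-in for a BenchmarkProfiler-like object: named timers with
    # start/end/get_duration. Not the real models.perf.benchmarking_utils class.
    def __init__(self):
        self._starts = {}
        self._durations = {}

    def start(self, key):
        self._starts[key] = time.perf_counter()

    def end(self, key):
        self._durations[key] = time.perf_counter() - self._starts[key]

    def get_duration(self, key):
        return self._durations.get(key, 0.0)


def per_iter_latency_us(profiler, num_iters, warmup_iters, name="all-gather-async-trace"):
    # Same arithmetic as the tests above: main-trace time minus warmup-trace time,
    # averaged over the non-warmup iterations and reported in microseconds.
    time_taken = profiler.get_duration(name) - profiler.get_duration(f"{name}-warmup")
    return time_taken / (num_iters - warmup_iters) * 1e6


if __name__ == "__main__":
    profiler = SimpleProfiler()
    profiler.start("all-gather-async-trace-warmup")
    time.sleep(0.01)  # placeholder for executing the warmup trace
    profiler.end("all-gather-async-trace-warmup")
    profiler.start("all-gather-async-trace")
    time.sleep(0.11)  # placeholder for executing the main trace
    profiler.end("all-gather-async-trace")
    print(f"{per_iter_latency_us(profiler, num_iters=1000, warmup_iters=100):.2f} us/iter")

With num_iters=1000 and warmup_iters=100 the average is taken over 900 iterations, matching the effective_iter variable used in the tests.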