From bcf48a620040e2f49013b76212880072edff8ee5 Mon Sep 17 00:00:00 2001 From: Jack Cai Date: Mon, 10 Feb 2025 22:23:29 +0000 Subject: [PATCH] added all reduce into llama ccl perf test and added proper measurement of e2e trace perf --- .../ccl/test_all_gather_TG_post_commit.py | 90 ++++--- .../operations/ccl/test_ccl_async_TG_llama.py | 255 +++++------------- .../operations/ccl/test_ccl_common.py | 1 - .../operations/ccl/test_new_all_reduce.py | 89 ++++-- .../device/kernels/compute/reduction.cpp | 1 - .../device/kernels/dataflow/worker_writer.cpp | 11 +- 6 files changed, 195 insertions(+), 252 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index 7f37600028a..2dcd7473d0c 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -64,6 +64,7 @@ def run_with_trace( n_worker=None, n_buffer=None, num_iter=20, + warmup_iters=0, use_all_gather_async=False, profiler=BenchmarkProfiler(), ): @@ -99,49 +100,68 @@ def run_with_trace( # Capture trace logger.info("Capturing trace") - trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - for i in range(num_iter): - if use_all_gather_async: - logger.info("Running all-gather async") - tt_out_tensor = ttnn.experimental.all_gather_async( - input_tensor, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles[i] - if type(ccl_semaphore_handles) == list - else ccl_semaphore_handles, - num_links=num_links, - memory_config=output_mem_config, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - ) - else: - tt_out_tensor = ttnn.all_gather( - input_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - memory_config=output_mem_config, - topology=all_gather_topology, - ) - ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) - for d in mesh_device.get_devices(): - ttnn.synchronize_device(d) + + def capture_trace(n_iters): + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for i in range(n_iters): + if use_all_gather_async: + logger.info("Running all-gather async") + tt_out_tensor = ttnn.experimental.all_gather_async( + input_tensor, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + multi_device_global_semaphore=ccl_semaphore_handles[i] + if type(ccl_semaphore_handles) == list + else ccl_semaphore_handles, + num_links=num_links, + memory_config=output_mem_config, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) + else: + tt_out_tensor = ttnn.all_gather( + input_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + memory_config=output_mem_config, + topology=all_gather_topology, + ) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + return trace_id + + if warmup_iters > 0: + trace_id_warmup = capture_trace(warmup_iters) + trace_id = capture_trace(num_iter) # Run the op logger.info("Starting Trace perf test...") + profiler.start("all-gather-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + for d in 
mesh_device.get_devices(): + ttnn.synchronize_device(d) + profiler.end("all-gather-async-trace-warmup") + profiler.start("all-gather-async-trace") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) for d in mesh_device.get_devices(): ttnn.synchronize_device(d) profiler.end("all-gather-async-trace") - logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s") - logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us") + time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration( + "all-gather-async-trace-warmup" + ) + effective_iter = num_iter - warmup_iters + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {time_taken / effective_iter} s") + logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us") return tt_out_tensor @@ -163,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( output_shard_spec: ttnn.ShardSpec = None, num_all_gather_instances: int = 1, num_iters: int = 1, + warmup_iters: int = 0, cluster_axis: int = 0, tile=(32, 32), trace_mode=False, @@ -277,6 +298,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( enable_persistent_fabric=enable_persistent_fabric, all_gather_topology=ttnn.Topology.Linear, num_iter=num_iters, + warmup_iters=warmup_iters, use_all_gather_async=use_all_gather_async, profiler=profiler, ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index fe967467e14..0d6795a9551 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -20,8 +20,8 @@ from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import ( run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows, ) -from tests.ttnn.unit_tests.operations.ccl.test_all_reduce_async import ( - run_all_reduce_with_mesh_tensor_along_row, +from tests.ttnn.unit_tests.operations.ccl.test_new_all_reduce import ( + run_all_reduce_impl, ) from models.perf.benchmarking_utils import BenchmarkProfiler @@ -89,9 +89,9 @@ def get_core_range_set(output_core_grid): ], ) @pytest.mark.parametrize( - "num_iters", + "num_iters, warmup_iters", [ - 5000, + (500, 100), ], ) @pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @@ -131,7 +131,7 @@ def get_core_range_set(output_core_grid): @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) @pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True) -def test_line_all_gather_sharded_on_TG_rows_llama( +def test_all_gather_llama( mesh_device, num_devices, output_shape, @@ -150,6 +150,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( enable_async, replication_factor, num_iters, + warmup_iters, perf_target_us, ): if len(mesh_device.get_devices()) != 32: @@ -185,6 +186,7 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async=enable_async, num_iters=num_iters, + warmup_iters=warmup_iters, input_shard_spec=input_shard_spec, output_shard_spec=output_shard_spec, num_all_gather_instances=replication_factor, @@ -197,7 +199,13 @@ def test_line_all_gather_sharded_on_TG_rows_llama( 
teardown_persistent_fabric=True, ) - latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6 + time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration( + "all-gather-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") if perf_target_us is not None: assert ( latency_us < perf_target_us @@ -206,222 +214,81 @@ def test_line_all_gather_sharded_on_TG_rows_llama( @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "num_devices, num_links", + "output_shape, cluster_axis, num_links, input_num_cores, output_num_cores, perf_target_us", [ - (4, 2), + ([1, 1, 32, 2048], 0, 4, 24, 16, 84), # FF2/DO all reduce + ([1, 1, 32, 1280], 1, 3, 24, 40, 60), # QKV all reduce + ([1, 1, 32, 3584], 1, 3, 24, 24, 69), # FF1 all reduce ], ) -@pytest.mark.parametrize( - "tensor_mem_layout, per_chip_input_shape, dim, input_shard_shape,shard_grid,layout", - ( - ( # ReduceScatter After FF1/3 (~100 us) - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - (1, 1, 32, 3840), - 3, - (32, 160), - ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 2))}), - ttnn.TILE_LAYOUT, - ), - ), -) -@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, - # ttnn.bfloat8_b, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize( - "buffer_type", + "num_iters, warmup_iters", [ - ttnn.BufferType.L1, + (1000, 100), ], ) @pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [8]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_reduce_scatter_sharded_on_TG_rows_llama( - mesh_device, - num_devices, - per_chip_input_shape, - tensor_mem_layout, - input_shard_shape, - shard_grid, - shard_grid_orientation, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - input_shard_spec = ttnn.ShardSpec( - shard_grid, - input_shard_shape, - shard_grid_orientation, - ) - - logger.warning("sharding not used due to issue #16699") - - run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( - mesh_device, - num_devices, - per_chip_input_shape, - ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout, - dim, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - # input_shard_spec=input_shard_spec, - num_iters=num_iters, - num_reduce_scatter_instances=replication_factor, - cluster_axis=1, - use_reduce_scatter_async=True, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, - ) - - -# Enumerate the post-commit cases explicitly -@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (4, 1, [1, 1, 32, 1280], ttnn.TILE_LAYOUT), # AllReduce after QKV (~110 us) - ], -) -@pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - ], + "device_params", + [{"trace_region_size": 23887872}], + 
indirect=True, ) @pytest.mark.parametrize( - "buffer_type", + "mesh_device", [ - ttnn.BufferType.L1, + (8, 4), ], + indirect=True, ) -@pytest.mark.parametrize("replication_factor", [8]) # 1, 8]) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_rows_llama( +def test_all_reduce_tg_llama( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, + num_links, + input_num_cores, + output_num_cores, + num_iters, + warmup_iters, + perf_target_us, + enable_async, + trace_mode, use_program_cache, function_level_defaults, - enable_async, - replication_factor, - num_iters=10, ): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") + profiler = BenchmarkProfiler() - run_all_reduce_with_mesh_tensor_along_row( + run_all_reduce_impl( mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, + output_shape, + cluster_axis, input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, + num_links, + input_num_cores, + output_num_cores, num_iters=num_iters, - num_all_reduce_instances=replication_factor, - cluster_axis=1, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, + warmup_iters=warmup_iters, + enable_async=enable_async, + trace_mode=trace_mode, + validate_all=False, + profiler=profiler, ) - -@skip_for_grayskull("Requires eth connected devices to run") -@pytest.mark.parametrize( - "num_devices, num_links, per_chip_output_shape, layout", - [ - (8, 1, [1, 1, 32, 2048], ttnn.TILE_LAYOUT), # AllReduce after DO and AllReduce after FF2 (~240 us) - # multi-links fail https://github.com/tenstorrent/tt-metal/issues/16699 - ], -) -@pytest.mark.parametrize( - "input_dtype", - [ - ttnn.bfloat16, - ], -) -@pytest.mark.parametrize( - "buffer_type", - [ - ttnn.BufferType.L1, - ], -) -@pytest.mark.parametrize("enable_async", [True]) -@pytest.mark.parametrize("replication_factor", [4]) -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) -@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) -def test_line_all_reduce_on_TG_cols_llama( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async, - replication_factor, - num_iters=10, -): - if len(mesh_device.get_devices()) != 32: - pytest.skip("Not TG!") - - logger.warning("sharding not used due to issue #16699") - - run_all_reduce_with_mesh_tensor_along_row( - mesh_device, - num_devices, - per_chip_output_shape, - num_links, - math_op, - input_dtype, - layout, - buffer_type, - use_program_cache, - function_level_defaults, - enable_async=enable_async, - num_iters=num_iters, - num_all_reduce_instances=replication_factor, - cluster_axis=0, - enable_persistent_fabric=True, - create_persistent_fabric=True, - teardown_persistent_fabric=True, + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} 
us") + if perf_target_us is not None: + assert ( + latency_us < perf_target_us + ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us" diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py index 65fa2a49b73..501f04ff30e 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py @@ -33,7 +33,6 @@ def teardown_fabric_interface(mesh_device): def create_global_semaphore_with_same_address(mesh_device, cores, initial_value): semaphore_handles = ttnn.create_global_semaphore_with_same_address(mesh_device, cores, initial_value) addrs = ttnn.get_global_semaphore_address(semaphore_handles) - logger.debug(f"from remote semaphore handle addresses: {addrs}") # assert all addresses are the same assert len(set(addrs)) == 1 return semaphore_handles diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 405b9106f3a..4182ef97dc6 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -20,6 +20,7 @@ num_cores_to_rectangle_grid, round_up, ) +from models.perf.benchmarking_utils import BenchmarkProfiler def run_all_reduce_impl( @@ -31,8 +32,11 @@ def run_all_reduce_impl( input_num_cores, output_num_cores, num_iters=1, + warmup_iters=0, enable_async=False, trace_mode=False, + validate_all=True, + profiler=BenchmarkProfiler(), ): cluster_shape = (8, 4) @@ -141,9 +145,9 @@ def run_all_reduce_impl( ##### Run the op ################################## - def run_op(): + def run_op(n_iters, store_all_results=True): outs = [] - for i in range(num_iters): + for i in range(n_iters): out = ttnn.experimental.all_reduce_async( tt_input_tensor, cluster_axis=cluster_axis, @@ -157,43 +161,65 @@ def run_op(): if not trace_mode: for d in mesh_device.get_devices(): ttnn.synchronize_device(d) - outs.append(out) + if store_all_results: + outs.append(out) - return outs + if store_all_results: + return outs + else: + return [out] if trace_mode: ##### Compile Model ##### logger.info("Compiling model") - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ##### Capture Trace ##### logger.info("Capturing trace") + if warmup_iters > 0: + trace_id_warmup = ttnn.begin_trace_capture(mesh_device, cq_id=0) + tt_outs = run_op(warmup_iters, store_all_results=validate_all) + ttnn.end_trace_capture(mesh_device, trace_id_warmup, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) ##### Run Trace ##### logger.info("Starting Trace perf test...") - time_start = time() + profiler.start("all-reduce-async-trace-warmup") + if warmup_iters > 0: + ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False) + ttnn.release_trace(mesh_device, trace_id_warmup) + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + profiler.end("all-reduce-async-trace-warmup") + + profiler.start("all-reduce-async-trace") ttnn.execute_trace(mesh_device, trace_id, blocking=False) ttnn.release_trace(mesh_device, trace_id) for d in mesh_device.get_devices(): ttnn.synchronize_device(d) - time_end = time() - 
logger.info(f"Time taken: {time_end - time_start} s") - logger.info(f"Time per iter: {(time_end - time_start) / num_iters} s") - logger.info(f"Time per iter: {(time_end - time_start) / num_iters * 1e6} us") + profiler.end("all-reduce-async-trace") + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" + ) + effective_iter = num_iters - warmup_iters + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {time_taken / effective_iter} s") + logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us") else: - tt_outs = run_op() + tt_outs = run_op(num_iters, store_all_results=validate_all) ################################## ##### Validation ################################## - for tensor_index in range(len(tt_outs)): - tt_out_tensor = tt_outs[tensor_index] - output_tensor = output_tensor_goldens_list[tensor_index] + def validate(tt_out_tensor, output_tensor): for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): # get_device_tensors returns row major, so we need to select the correct golden tensor if cluster_axis == 0: @@ -209,7 +235,17 @@ def run_op(): else: eq, output = comp_pcc(tt_output_tensor, output_tensor_) assert eq, f"{i} FAILED: {output}" - logger.info(f"PCC output for {tensor_index} is: {output}") + logger.info(f"PCC output is: {output}") + + if validate_all: + for tensor_index in range(len(tt_outs)): + tt_out_tensor = tt_outs[tensor_index] + output_tensor = output_tensor_goldens_list[tensor_index] + validate(tt_out_tensor, output_tensor) + else: + tt_out_tensor = tt_outs[-1] + output_tensor = output_tensor_goldens_list[-1] + validate(tt_out_tensor, output_tensor) for i in range(mesh_device.get_num_devices()): assert ( @@ -249,7 +285,12 @@ def run_op(): ttnn.bfloat8_b, ], ) -@pytest.mark.parametrize("num_iters", [5]) +@pytest.mark.parametrize( + "num_iters, warmup_iters", + [ + (1000, 100), + ], +) @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("trace_mode", [True]) @pytest.mark.parametrize( @@ -273,11 +314,14 @@ def test_all_reduce( input_num_cores, output_num_cores, num_iters, + warmup_iters, enable_async, trace_mode, use_program_cache, function_level_defaults, ): + profiler = BenchmarkProfiler() + run_all_reduce_impl( mesh_device, output_shape, @@ -287,6 +331,17 @@ def test_all_reduce( input_num_cores, output_num_cores, num_iters=num_iters, + warmup_iters=warmup_iters, enable_async=enable_async, trace_mode=trace_mode, + validate_all=False, + profiler=profiler, + ) + + time_taken = profiler.get_duration("all-reduce-async-trace") - profiler.get_duration( + "all-reduce-async-trace-warmup" ) + effective_iter = num_iters - warmup_iters + latency_us = time_taken / effective_iter * 1e6 + logger.info(f"Time taken: {time_taken} s") + logger.info(f"Time per iter: {latency_us} us") diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp index a68fd52889b..58136a8232f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/compute/reduction.cpp @@ -4,7 +4,6 @@ #include #include "compute_kernel_api/eltwise_binary.h" -#include "compute_kernel_api/tile_move_copy.h" namespace NAMESPACE { void MAIN { diff --git 
a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp index bb25661536d..f657ae6f065 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/kernels/dataflow/worker_writer.cpp @@ -4,6 +4,8 @@ #include "dataflow_api.h" #include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" #include "cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp" #include #include @@ -133,13 +135,12 @@ void kernel_main() { // 2. mcast output ready semaphore auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); - pkt_hdr->to_atomic_inc(); + uint64_t out_ready_sem_noc_addr_in_pkt = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr, 0); pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ - out_ready_sem_bank_addr, + out_ready_sem_noc_addr_in_pkt, static_cast(1), // increment 1 - 32, - static_cast(out_ready_sem_noc0_x), - static_cast(out_ready_sem_noc0_y)}); + 32}); // Write the mcast packet (forward) if (fabric_connection.has_forward_connection()) { fabric_connection.get_forward_connection().wait_for_empty_write_slot();
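
For reference, the end-to-end timing scheme this patch applies in all three tests can be summarized as a standalone sketch. The helper below is illustrative, not code from the patch: the function name measure_trace_latency_us and the "key" default are assumptions, and it presumes trace_id_warmup was captured with warmup_iters iterations and trace_id with num_iters, as run_with_trace and run_all_reduce_impl do. Both traces are executed and timed with BenchmarkProfiler, the warmup duration is subtracted from the main duration, and the result is normalized by num_iters - warmup_iters before converting to microseconds.

    from models.perf.benchmarking_utils import BenchmarkProfiler
    import ttnn

    def measure_trace_latency_us(mesh_device, trace_id, trace_id_warmup, num_iters, warmup_iters, key="ccl-async-trace"):
        # Sketch only: mirrors the pattern used in the updated tests, not an exact copy.
        profiler = BenchmarkProfiler()

        def run_and_sync(tid):
            # Execute a previously captured trace and wait for every device in the mesh to finish.
            ttnn.execute_trace(mesh_device, tid, blocking=False)
            ttnn.release_trace(mesh_device, tid)
            for d in mesh_device.get_devices():
                ttnn.synchronize_device(d)

        profiler.start(f"{key}-warmup")
        if warmup_iters > 0:
            run_and_sync(trace_id_warmup)
        profiler.end(f"{key}-warmup")

        profiler.start(key)
        run_and_sync(trace_id)
        profiler.end(key)

        # The warmup trace and the main trace share the same fixed launch and synchronization
        # overhead, so subtracting the two durations and dividing by the difference in iteration
        # counts yields a per-iteration latency with that overhead cancelled out.
        time_taken = profiler.get_duration(key) - profiler.get_duration(f"{key}-warmup")
        return time_taken / (num_iters - warmup_iters) * 1e6  # microseconds per iteration

Timing a short warmup trace and subtracting it, rather than timing the main trace alone, is what makes the reported per-iteration latency reflect steady-state cost rather than including one-time trace dispatch overhead.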