Add support for minimal all_reduce for Llama shapes #17792

Status: Draft. Wants to merge 25 commits into base: main.

Commits (25):
503ca1c (caixunshiren, Feb 7, 2025): added ccl async llama perf and correctness measurement tests to CI
ccd3f98 (caixunshiren, Feb 10, 2025): updated ccl perf target for llama all gather async
0c9ccd7 (avoraTT, Feb 3, 2025): WIP adding infra for the new all reduce.
c2a11de (avoraTT, Feb 3, 2025): interleaved all gather works with good PCC. Next step: add reduction …
4b67c01 (avoraTT, Feb 3, 2025): wip reduction stuff.
ce9e52d (kpaigwar, Feb 4, 2025): #0: added noc semaphore multicast in writer, seeing a hang on noc_sem…
bf9cb47 (avoraTT, Feb 4, 2025): Add fix for reduction worker hang.
600352e (avoraTT, Feb 4, 2025): Add reduction and output cb. Currently, the reduction kernel does a c…
5d4d879 (avoraTT, Feb 4, 2025): All-reduce for FF1/FF3 works
27414ef (avoraTT, Feb 4, 2025): Add support for reshard. TODO: add support to drop padding from input…
31119ef (avoraTT, Feb 5, 2025): Add support for unpadded shapes.
24e65bb (avoraTT, Feb 5, 2025): Remove dprints.
7251015 (avoraTT, Feb 5, 2025): Fix bug in mcast bbox. Fix QKV output num cores.
b67ea6d (kpaigwar, Feb 6, 2025): #0: multi-link support added for 3 all_reduce. Link=3 fails with kern…
21f1779 (kpaigwar, Feb 6, 2025): #0: multi-link=3 works
e4a7b09 (avoraTT, Feb 6, 2025): Add cleanup for multi-link.
e0ce950 (avoraTT, Feb 6, 2025): Rebase and fix/cleanup stuff.
ea2edb7 (avoraTT, Feb 7, 2025): Clean up pytest and enable trace.
b6277d9 (avoraTT, Feb 7, 2025): Adding gsem fix for multi-iter.
4520362 (kpaigwar, Feb 7, 2025): #0: added api to subtract corerangesets
203e2e8 (kpaigwar, Feb 7, 2025): #0: updated choose_worker_cores function to omit reserved_cores
694ef4d (kpaigwar, Feb 7, 2025): #0: fix placement of link worker cores
a131897 (avoraTT, Feb 8, 2025): Add support for input shard not divisble by output shard.
bcf48a6 (caixunshiren, Feb 10, 2025): added all reduce into llama ccl perf test and added proper measuremen…
b779894 (johanna-rock-tt, Feb 6, 2025): Extend llama sharded all gather for LN
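
Taken together, the commits above assemble the minimal all_reduce as an all-gather of the per-device shards followed by a local reduction (see the "interleaved all gather" and "Add reduction and output cb" commits). A rough host-side sketch of that decomposition, in NumPy and purely for illustration; none of these names come from the PR:

import numpy as np

def all_reduce_via_gather(shards):
    # shards: one tensor per device, all with the same shape.
    gathered = np.stack(shards, axis=0)      # "all-gather": every device ends up with every shard
    reduced = gathered.sum(axis=0)           # local reduction of the gathered shards
    return [reduced.copy() for _ in shards]  # every device now holds the fully reduced tensor

# Example: 4 devices, each holding a 32x32 shard of activations.
outputs = all_reduce_via_gather([np.random.rand(32, 32) for _ in range(4)])
assert all(np.allclose(outputs[0], out) for out in outputs)
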
Changes from all commits:

tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py

@@ -14,6 +14,7 @@
    teardown_fabric_interface,
    create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler


def report_mismatches(golden, actual, max_printable=None):
@@ -63,7 +64,9 @@ def run_with_trace(
    n_worker=None,
    n_buffer=None,
    num_iter=20,
    warmup_iters=0,
    use_all_gather_async=False,
    profiler=BenchmarkProfiler(),
):
    # Compile Run
    logger.info("Compiling model")
@@ -97,44 +100,68 @@

    # Capture trace
    logger.info("Capturing trace")
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for i in range(num_iter):
        if use_all_gather_async:
            logger.info("Running all-gather async")
            tt_out_tensor = ttnn.experimental.all_gather_async(
                input_tensor,
                dim,
                cluster_axis=cluster_axis,
                mesh_device=mesh_device,
                topology=ttnn.Topology.Linear,
                multi_device_global_semaphore=ccl_semaphore_handles[i]
                if type(ccl_semaphore_handles) == list
                else ccl_semaphore_handles,
                num_links=num_links,
                memory_config=output_mem_config,
                subdevice_id=worker_sub_device_id,
                enable_persistent_fabric_mode=enable_persistent_fabric,
            )
        else:
            tt_out_tensor = ttnn.all_gather(
                input_tensor,
                dim=dim,
                cluster_axis=cluster_axis,
                mesh_device=mesh_device,
                num_links=num_links,
                memory_config=output_mem_config,
                topology=all_gather_topology,
            )
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    def capture_trace(n_iters):
        trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
        for i in range(n_iters):
            if use_all_gather_async:
                logger.info("Running all-gather async")
                tt_out_tensor = ttnn.experimental.all_gather_async(
                    input_tensor,
                    dim,
                    cluster_axis=cluster_axis,
                    mesh_device=mesh_device,
                    topology=ttnn.Topology.Linear,
                    multi_device_global_semaphore=ccl_semaphore_handles[i]
                    if type(ccl_semaphore_handles) == list
                    else ccl_semaphore_handles,
                    num_links=num_links,
                    memory_config=output_mem_config,
                    subdevice_id=worker_sub_device_id,
                    enable_persistent_fabric_mode=enable_persistent_fabric,
                )
            else:
                tt_out_tensor = ttnn.all_gather(
                    input_tensor,
                    dim=dim,
                    cluster_axis=cluster_axis,
                    mesh_device=mesh_device,
                    num_links=num_links,
                    memory_config=output_mem_config,
                    topology=all_gather_topology,
                )
        ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
        for d in mesh_device.get_devices():
            ttnn.synchronize_device(d)
        return trace_id

    if warmup_iters > 0:
        trace_id_warmup = capture_trace(warmup_iters)
    trace_id = capture_trace(num_iter)

    # Run the op
    logger.info("Starting Trace perf test...")
    profiler.start("all-gather-async-trace-warmup")
    if warmup_iters > 0:
        ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
        ttnn.release_trace(mesh_device, trace_id_warmup)
        for d in mesh_device.get_devices():
            ttnn.synchronize_device(d)
    profiler.end("all-gather-async-trace-warmup")

    profiler.start("all-gather-async-trace")
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    ttnn.release_trace(mesh_device, trace_id)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)
    profiler.end("all-gather-async-trace")
    time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
        "all-gather-async-trace-warmup"
    )
    effective_iter = num_iter - warmup_iters
    logger.info(f"Time taken: {time_taken} s")
    logger.info(f"Time per iter: {time_taken / effective_iter} s")
    logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us")

    return tt_out_tensor

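The new measurement path above captures two traces, one with warmup_iters iterations and one with num_iter iterations, executes both, and subtracts the warmup duration so the reported per-iteration time reflects steady-state behaviour. A self-contained sketch of that arithmetic using plain Python timers (illustrative only; the test itself relies on BenchmarkProfiler and ttnn traces):

import time

def timed(run, iters):
    # Stand-in for executing a captured trace that contains `iters` iterations.
    start = time.perf_counter()
    run(iters)
    return time.perf_counter() - start

def per_iter_seconds(run, num_iter, warmup_iters):
    # Mirrors run_with_trace: time the warmup trace first, then the full trace,
    # subtract the two, and average over the remaining iterations.
    t_warmup = timed(run, warmup_iters) if warmup_iters > 0 else 0.0
    t_total = timed(run, num_iter)
    return (t_total - t_warmup) / (num_iter - warmup_iters)

# Example with a dummy workload of roughly 1 ms per iteration.
print(per_iter_seconds(lambda n: time.sleep(0.001 * n), num_iter=20, warmup_iters=5))
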
@@ -156,10 +183,12 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    output_shard_spec: ttnn.ShardSpec = None,
    num_all_gather_instances: int = 1,
    num_iters: int = 1,
    warmup_iters: int = 0,
    cluster_axis: int = 0,
    tile=(32, 32),
    trace_mode=False,
    debug=False,
    profiler=BenchmarkProfiler(),
    # New all-gather-async and persistent fabric params
    use_all_gather_async=False,
    enable_persistent_fabric=False,
@@ -269,7 +298,9 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
            enable_persistent_fabric=enable_persistent_fabric,
            all_gather_topology=ttnn.Topology.Linear,
            num_iter=num_iters,
            warmup_iters=warmup_iters,
            use_all_gather_async=use_all_gather_async,
            profiler=profiler,
        )

    else:
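
The "ccl perf target" commits indicate that the measured latency is ultimately compared against a target in the llama CCL perf test. A hedged sketch of such a check; the function name, threshold, and numbers below are illustrative and not taken from this PR:

def check_perf_target(time_taken_s, num_iter, warmup_iters, target_us):
    # Per-iteration latency in microseconds, using the same formula the test logs.
    effective_iter = num_iter - warmup_iters
    per_iter_us = time_taken_s / effective_iter * 1e6
    assert per_iter_us <= target_us, (
        f"per-iteration latency {per_iter_us:.2f} us exceeds target {target_us} us"
    )

# Example with made-up numbers: 75 iterations, 5 of them warmup, 120 us target.
check_perf_target(time_taken_s=0.0075, num_iter=75, warmup_iters=5, target_us=120)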