Skip to content

Commit

Permalink
added all reduce into llama ccl perf test and added proper measuremen…
Browse files Browse the repository at this point in the history
…t of e2e trace perf
  • Loading branch information
caixunshiren committed Feb 10, 2025
1 parent a131897 commit bcf48a6
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 252 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def run_with_trace(
n_worker=None,
n_buffer=None,
num_iter=20,
warmup_iters=0,
use_all_gather_async=False,
profiler=BenchmarkProfiler(),
):
Expand Down Expand Up @@ -99,49 +100,68 @@ def run_with_trace(

# Capture trace
logger.info("Capturing trace")
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(num_iter):
if use_all_gather_async:
logger.info("Running all-gather async")
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
if use_all_gather_async:
logger.info("Running all-gather async")
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iter)

# Run the op
logger.info("Starting Trace perf test...")
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)
profiler.end("all-gather-async-trace")
logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")
time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
"all-gather-async-trace-warmup"
)
effective_iter = num_iter - warmup_iters
logger.info(f"Time taken: {time_taken} s")
logger.info(f"Time per iter: {time_taken / effective_iter} s")
logger.info(f"Time per iter: {time_taken / effective_iter * 1e6} us")

return tt_out_tensor

Expand All @@ -163,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
output_shard_spec: ttnn.ShardSpec = None,
num_all_gather_instances: int = 1,
num_iters: int = 1,
warmup_iters: int = 0,
cluster_axis: int = 0,
tile=(32, 32),
trace_mode=False,
Expand Down Expand Up @@ -277,6 +298,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
enable_persistent_fabric=enable_persistent_fabric,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
warmup_iters=warmup_iters,
use_all_gather_async=use_all_gather_async,
profiler=profiler,
)
Expand Down
Loading

0 comments on commit bcf48a6

Please sign in to comment.