#17829: binary_ng survey - add - same dtypes
KalaivaniMCW committed Feb 12, 2025
1 parent a0fa9d0 commit bec0f7d
Showing 2 changed files with 250 additions and 0 deletions.
245 changes: 245 additions & 0 deletions tests/sweep_framework/sweeps/eltwise/binary/add/add_ng_survey.py
@@ -0,0 +1,245 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from functools import partial

import torch
import random
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
# TIMEOUT = 30

# random.seed(0)

# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites whose keys are the arguments of the run function and whose values are lists of possible inputs.
# Each suite has a name (here "t8_ngadd_bf4b") that associates the generated test vectors with this specific suite of inputs;
# see the expansion sketch after the parameters dict below.
# Developers can also create their own generator functions and pass them in as parameter values.
parameters = {
"t8_ngadd_bf4b": {
"input_shape": [{"self": [1, 1, 1024, 1024], "other": [1, 1, 1024, 1024]}],
# "input_shape": [{"self": [1, 1, 512, 512], "other": [1, 1, 512, 512]}], // for float32 and int32 dtypes
# "input_a_dtype": [ttnn.bfloat16],
# "input_b_dtype": [ttnn.bfloat16],
# "input_a_dtype": [ttnn.float32],
# "input_b_dtype": [ttnn.float32],
# "input_a_dtype": [ttnn.int32],
# "input_b_dtype": [ttnn.int32],
# "input_a_dtype": [ttnn.bfloat8_b],
# "input_b_dtype": [ttnn.bfloat8_b],
"input_a_dtype": [ttnn.bfloat4_b],
"input_b_dtype": [ttnn.bfloat4_b],
"input_a_layout": [ttnn.TILE_LAYOUT],
"input_b_layout": [ttnn.TILE_LAYOUT],
"input_mem_config": [
{"a_mem": "l1_interleaved", "b_mem": "l1_interleaved"},
{"a_mem": "l1_interleaved", "b_mem": "dram_interleaved"},
{"a_mem": "dram_interleaved", "b_mem": "l1_interleaved"},
{"a_mem": "dram_interleaved", "b_mem": "dram_interleaved"}, # l1 - dram combi
{"a_mem": "l1_height_sharded_rm", "b_mem": "l1_height_sharded_rm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_height_sharded_rm"},
{"a_mem": "l1_height_sharded_rm", "b_mem": "dram_interleaved"}, # HS
{"a_mem": "l1_width_sharded_rm", "b_mem": "l1_width_sharded_rm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_width_sharded_rm"},
{"a_mem": "l1_width_sharded_rm", "b_mem": "dram_interleaved"}, # WS
{"a_mem": "l1_block_sharded_rm", "b_mem": "l1_block_sharded_rm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_block_sharded_rm"},
{"a_mem": "l1_block_sharded_rm", "b_mem": "dram_interleaved"}, # BS #row_major orientation
{"a_mem": "l1_height_sharded_cm", "b_mem": "l1_height_sharded_cm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_height_sharded_cm"},
{"a_mem": "l1_height_sharded_cm", "b_mem": "dram_interleaved"}, # HS
{"a_mem": "l1_width_sharded_cm", "b_mem": "l1_width_sharded_cm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_width_sharded_cm"},
{"a_mem": "l1_width_sharded_cm", "b_mem": "dram_interleaved"}, # WS
{"a_mem": "l1_block_sharded_cm", "b_mem": "l1_block_sharded_cm"},
{"a_mem": "dram_interleaved", "b_mem": "l1_block_sharded_cm"},
{"a_mem": "l1_block_sharded_cm", "b_mem": "dram_interleaved"}, # BS #col_major orientation
],
# "input_a_memory_config": [
# "l1_interleaved",
# "dram_interleaved",
# "l1_height_sharded_rm",
# "l1_width_sharded_rm",
# "l1_block_sharded_rm",
# ],
# "input_b_memory_config": [
# "l1_interleaved",
# "dram_interleaved",
# "l1_height_sharded_rm",
# "l1_width_sharded_rm",
# "l1_block_sharded_rm",
# ],
# "out_memory_config": [
# "l1_interleaved",
# "dram_interleaved",
# "l1_height_sharded_rm",
# "l1_width_sharded_rm",
# "l1_block_sharded_rm",
# ],
},
}
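
# The runner typically expands each suite above into individual test vectors by taking the
# cartesian product of the value lists. A rough sketch of that expansion (an assumption about
# the framework internals, illustrated with itertools rather than the real generator code):
#   import itertools
#   suite = parameters["t8_ngadd_bf4b"]
#   keys, values = zip(*suite.items())
#   test_vectors = [dict(zip(keys, combo)) for combo in itertools.product(*values)]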


def return_mem_config(mem_config_string):
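    # Map a short config name to a concrete ttnn.MemoryConfig: DRAM/L1 interleaved, or an
    # L1 sharded config. Every sharded variant spreads the 1024x1024 input over a 2x4
    # (8-core) grid with an explicit shard shape; the _rm/_cm suffix selects ROW_MAJOR vs
    # COL_MAJOR shard orientation (for height/width sharding the shard shape is swapped
    # between the two orientations).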
if mem_config_string == "l1_interleaved":
return ttnn.L1_MEMORY_CONFIG
elif mem_config_string == "dram_interleaved":
return ttnn.DRAM_MEMORY_CONFIG
elif mem_config_string == "l1_height_sharded_rm":
return ttnn.create_sharded_memory_config(
# shape=(512 // 8, 512),
shape=(1024 // 8, 1024),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
elif mem_config_string == "l1_height_sharded_cm":
return ttnn.create_sharded_memory_config(
# shape=(512, 512 // 8),
shape=(1024, 1024 // 8),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.COL_MAJOR,
use_height_and_width_as_shard_shape=True,
)
elif mem_config_string == "l1_width_sharded_rm":
return ttnn.create_sharded_memory_config(
# shape=(512, 512 // 8),
shape=(1024, 1024 // 8),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
elif mem_config_string == "l1_width_sharded_cm":
return ttnn.create_sharded_memory_config(
# shape=(512 // 8, 512),
shape=(1024 // 8, 1024),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.COL_MAJOR,
use_height_and_width_as_shard_shape=True,
)
elif mem_config_string == "l1_block_sharded_rm":
return ttnn.create_sharded_memory_config(
# shape=(512 // 2, 512 // 4),
shape=(1024 // 2, 1024 // 4),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.BLOCK,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
elif mem_config_string == "l1_block_sharded_cm":
return ttnn.create_sharded_memory_config(
# shape=(512 // 2, 512 // 4),
shape=(1024 // 2, 1024 // 4),
core_grid=ttnn.CoreGrid(y=2, x=4),
strategy=ttnn.ShardStrategy.BLOCK,
orientation=ttnn.ShardOrientation.COL_MAJOR,
use_height_and_width_as_shard_shape=True,
)
    raise ValueError("Input mem_config_string is not valid!")


# This is the run function for the test, defined by the developer.
# It must accept the parameters defined above as inputs.
# The runner calls this function once per test vector, and the values it returns are stored as results.
# If you defined a device_mesh_fixture above, the object it yielded is passed into this function as 'device'; otherwise it is the default ttnn device opened by the infra.
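# Roughly, the runner invokes something like test_module.run(**test_vector, device=device)
# for each generated vector (a simplification of the actual sweeps_runner call, shown only
# for orientation).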
def run(
input_shape,
input_a_dtype,
input_b_dtype,
input_a_layout,
input_b_layout,
input_mem_config,
# input_a_memory_config,
# input_b_memory_config,
# out_memory_config,
*,
device,
) -> list:
torch.manual_seed(0)

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.bfloat16), input_a_dtype
)(input_shape["self"])

if isinstance(input_shape["other"], list):
torch_input_tensor_b = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.bfloat16), input_b_dtype
)(input_shape["other"])
else:
torch_input_tensor_b = torch.tensor(input_shape["other"], dtype=torch.bfloat16)

input_a_memory_config = input_mem_config["a_mem"]
input_b_memory_config = input_mem_config["b_mem"]

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
dtype=input_a_dtype,
layout=input_a_layout,
device=device,
memory_config=return_mem_config(input_a_memory_config),
)

input_tensor_b = ttnn.from_torch(
torch_input_tensor_b,
dtype=input_b_dtype,
layout=input_b_layout,
device=device,
memory_config=return_mem_config(input_b_memory_config),
)

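    # bfloat8_b and bfloat4_b are lossy block-float formats, so read the quantized values
    # back from the device tensors and use them as the torch reference inputs; this way the
    # golden reference is computed from the same quantized data the device sees.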
if input_a_dtype == ttnn.bfloat8_b:
torch_input_tensor_a = ttnn.to_torch(input_tensor_a)

if input_b_dtype == ttnn.bfloat8_b:
torch_input_tensor_b = ttnn.to_torch(input_tensor_b)

if input_a_dtype == ttnn.bfloat4_b:
torch_input_tensor_a = ttnn.to_torch(input_tensor_a)

if input_b_dtype == ttnn.bfloat4_b:
torch_input_tensor_b = ttnn.to_torch(input_tensor_b)

golden_function = ttnn.get_golden_function(ttnn.experimental.add)
torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)

start_time = start_measuring_time()
# result = ttnn.experimental.add(input_tensor_a, input_tensor_b, memory_config=return_mem_config(out_memory_config))
result = ttnn.experimental.add(input_tensor_a, input_tensor_b)
output_tensor = ttnn.to_torch(result)
e2e_perf = stop_measuring_time(start_time)

return [check_with_pcc(torch_output_tensor, output_tensor, pcc=0.99), e2e_perf]


# sweeps output
# using ttnn.add - all 22 vectors pass
# | t4_add | 22 | 0 | 0 | 0 | 0 | 0 | 0 |
# using ttnn.experimental.add - 16 vectors pass (bf16)
# | t5_ngadd | 16 | 6 | 0 | 0 | 0 | 0 | 0 |
# using ttnn.experimental.add - 10 vectors pass (fp32) - 6 L1 vectors fail at shape (1, 1, 1024, 1024)
# | t5_ngadd_fp32 | 10 | 6 | 0 | 0 | 6 | 0 | 0 |
# using ttnn.experimental.add - 16 vectors pass (fp32) at shape (1, 1, 512, 512)
# | t6_ngadd_fp32 | 16 | 6 | 0 | 0 | 0 | 0 | 0 |
# using ttnn.experimental.add - 16 vectors pass (int32) at shape (1, 1, 512, 512)
# | t6_ngadd_int32 | 16 | 6 | 0 | 0 | 0 | 0 | 0 |
# using ttnn.experimental.add - 4 vectors pass (bf8_b) - all sharded configs fail (bf8_b is typecast to bf16) with:
# message: TT_FATAL @ /home/ubuntu/Kalai/tt-metal/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp:45: input_tile_size == output_tile_size
# info: Input and output tile size should be same
# | t7_ngadd_bf8b | 4 | 18 | 0 | 0 | 0 | 0 | 0 |
# using ttnn.experimental.add - 4 vectors pass (bf4_b) - all sharded configs fail (bf4_b is typecast to bf16) with:
# message: TT_FATAL @ /home/ubuntu/Kalai/tt-metal/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_sharded_program_factory.cpp:45: input_tile_size == output_tile_size
# info: Input and output tile size should be same
# | t8_ngadd_bf4b | 4 | 18 | 0 | 0 | 0 | 0 | 0 |
5 changes: 5 additions & 0 deletions tests/sweep_framework/sweeps_runner.py
@@ -95,6 +95,11 @@ def run(test_module, input_queue, output_queue):
output_queue.put([status, message, e2e_perf, perf_result])
else:
output_queue.put([status, message, e2e_perf, None])
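# Log the failing test vector and its error message to the console for easier triage.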
if not status:
print("-----------------------")
print("current parameter ", test_vector)
print("STATUS", status)
print("message", message)
except Empty as e:
try:
# Run teardown in mesh_device_fixture