Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jul 29, 2024
1 parent f96d03c commit b5be3c6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
11 changes: 9 additions & 2 deletions tools/mlir_bench/mlp_bench.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ if ! [ "$(command -v ${BENCH_RUNNER})" ]; then
echo "Missing benchmark runner"
exit 1
fi
if [ "${BASELINE_MODEL}" ] && [ ${IS_DYNAMIC} ]; then
echo "Baseline models with dynamic shapes not supported"
exit 1
fi

# Kernel config.
INPUT_SIZES=( 1024 2048 4096 8192 )
Expand Down Expand Up @@ -86,8 +90,11 @@ for OUT_SIZE in "${OUTPUT_SIZES[@]}"; do
# No native support for bf16, use simple f16 instead.
PRECISION="f16"
fi
BENCH_FLAGS="-m ${MODEL_NAME} -d CPU -data_shape [${OUT_SIZE,IN_SIZE}]\
-ip ${PRECISION}"
if [ ${IS_DYNAMIC} ]; then
DATA_SHAPE=(-data_shape [${OUT_SIZE,IN_SIZE}])
fi
BENCH_FLAGS="-m ${MODEL_NAME} -d CPU \
-ip ${PRECISION} ${DATA_SHAPE[@]}"
${BENCH_RUNNER} ${BENCH_FLAGS} 2>/dev/null | \
sed -nE "s/.*\[ INFO \]\s*Median:\s*([0-9.]+).*/\\1/p"
done
Expand Down
13 changes: 7 additions & 6 deletions tools/mlir_bench/ov_model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,23 +184,24 @@ def forward(self, a, b):
return self.relu(c)


def baseline_MLP(sizes: list[int], data_type: str) -> tuple[nn.Model, list]:
def baseline_MLP(model_desc: str, data_type: str, is_dynamic: bool) -> tuple[nn.Model, list]:
sizes = get_layer_sizes(model_desc)
assert len(sizes) == 3, "Invalid baseline MLP sizes"
mlp = BaselineMLP(sizes, get_torch_type(data_type))
m = sizes[0]
n = sizes[1]
k = sizes[2]
input_shapes = get_layer_inputs(model_desc, is_dynamic)[0]
m = input_shapes[0]
n = input_shapes[1]
k = input_shapes[2]
ov_type = get_ov_type(data_type)
inputs = [(ov.PartialShape([m, k]), ov_type), (ov.PartialShape([k, n]), ov_type)]
return (mlp, inputs)


def generate_baseline_model(model_desc: str, data_type: str, shape_type: str, file_name: str):
model_name = get_layer_name(model_desc)
input_shapes = get_layer_inputs(model_desc, shape_type == 'dynamic')

if model_name == 'mlp':
baseline_tuple = baseline_MLP(*input_shapes, data_type)
baseline_tuple = baseline_MLP(model_desc, data_type, shape_type == 'dynamic')
else:
assert False, f"Unsupported baseline model data type {model_name}"

Expand Down

0 comments on commit b5be3c6

Please sign in to comment.