Skip to content

Commit

Permalink
[e2e] add execution error in compatibility test
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Aug 16, 2024
1 parent 0e05851 commit 51470c4
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 43 deletions.
96 changes: 56 additions & 40 deletions tests/compatibility_test/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@
import torch
import numpy as np
import re
import traceback

from reporting import TestResult


class BRTBackend:

def __init__(self, device, brt_file_path):
_stream = None
self.device = None
if device == "CPU":
self.session = brt.Session(device=device.upper(), )
self.session = brt.Session(
device=device.upper(),
)
self.device = "cpu"
_stream = None
else:
raise NotImplementedError(
f"Compatible test for {device} not implement")
raise NotImplementedError(f"Compatible test for {device} not implement")

self.session.load(brt_file_path)
self.req = self.session.new_request_context(_stream)
Expand All @@ -45,10 +46,12 @@ def _generate_torch_outputs(self):
outputs = []
for offset in self.session.get_output_arg_offsets():
outputs.append(
torch.empty(self.session.get_static_shape(offset),
dtype=brt_dtype_to_torch_dtype(
self.session.get_data_type(offset)),
device=self.device))
torch.empty(
self.session.get_static_shape(offset),
dtype=brt_dtype_to_torch_dtype(self.session.get_data_type(offset)),
device=self.device,
)
)
return outputs

def compare(self, inputs, goldens):
Expand All @@ -57,12 +60,10 @@ def compare(self, inputs, goldens):
assert len(self.session.get_output_arg_offsets()) == len(outputs)
assert len(outputs) == len(goldens)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(
arg.shape)
assert list(self.session.get_static_shape(offset)) == list(arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(
ret.shape)
assert list(self.session.get_static_shape(offset)) == list(ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()
self.req.run()
Expand All @@ -71,32 +72,47 @@ def compare(self, inputs, goldens):


def run_and_check_mlir(target, name, inp_files, out_files, byre_file):

_device = None
if target == "cpu":
_device = "CPU"

brt_backend = BRTBackend(device=_device, brt_file_path=byre_file)

cmp_res = []
for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
inp = np.load(input_file, allow_pickle=True)
inp = [
torch.from_numpy(inp[f]).contiguous().to(_device.lower())
for f in inp.files
try:
_device = None
if target == "cpu":
_device = "CPU"

brt_backend = BRTBackend(device=_device, brt_file_path=byre_file)

cmp_res = []
for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
inp = np.load(input_file, allow_pickle=True)
inp = [
torch.from_numpy(inp[f]).contiguous().to(_device.lower())
for f in inp.files
]
tgt = np.load(target_file, allow_pickle=True)
tgt = [
torch.from_numpy(tgt[f]).contiguous().to(_device.lower())
for f in tgt.files
]
if brt_backend.compare(inp, tgt):
cmp_res.append(
TestResult(
name + str(idx), execution_error=None, numerical_error=None
)
)
else:
cmp_res.append(
TestResult(
name + str(idx),
execution_error=None,
numerical_error=f"input is {input_file}, output not match {target_file}",
)
)
return cmp_res
except Exception as e:
return [
TestResult(
name,
execution_error="".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
numerical_error=None,
)
]
tgt = np.load(target_file, allow_pickle=True)
tgt = [
torch.from_numpy(tgt[f]).contiguous().to(_device.lower())
for f in tgt.files
]
if brt_backend.compare(inp, tgt):
cmp_res.append(TestResult(name + str(idx), numerical_error=None))
else:
cmp_res.append(
TestResult(
name + str(idx),
numerical_error=
f"input is {input_file}, output not match {target_file}"))

return cmp_res
12 changes: 10 additions & 2 deletions tests/compatibility_test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ def parse_args():
required=True,
help="Directory has test cases",
)
parser.add_argument(
"-v",
"--verbose",
default=False,
action="store_true",
)
args = parser.parse_args()
return args


def run(testdir):
def run(testdir, verbose=False):
result = []
conf_file = os.path.join(testdir, "testcase.json")
if not os.path.exists(conf_file):
Expand All @@ -64,6 +70,8 @@ def run(testdir):
)
if not os.path.exists(byre_file):
raise RuntimeError(f"byre file{byre_file} not found")
if verbose:
print(f"Running {name}")
result += run_and_check_mlir(target, name, input_files,
golden_files, byre_file)
return result
Expand All @@ -72,7 +80,7 @@ def run(testdir):
def main():
args = parse_args()

results = run(args.testdir)
results = run(args.testdir, verbose=args.verbose)

failed = report_results(results)
sys.exit(1 if failed else 0)
Expand Down
8 changes: 7 additions & 1 deletion tests/compatibility_test/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@

class TestResult(NamedTuple):
unique_name: str
execution_error: Optional[str]
numerical_error: Optional[str]


def report_results(results: List[TestResult]):
fail_case = []
pass_case = []
for result in results:
if result.numerical_error is not None:
if result.execution_error is not None:
fail_case.append([
result.unique_name, "execution failed: " + result.unique_name +
"\n" + result.execution_error
])
elif result.numerical_error is not None:
fail_case.append([
result.unique_name, "numerical failed: " + result.unique_name +
"\n" + result.numerical_error
Expand Down

0 comments on commit 51470c4

Please sign in to comment.