Skip to content

Commit

Permalink
Added better error messages for type validators (stack trace) (#7999)
Browse files Browse the repository at this point in the history
* Added better error messages for type validators (stack trace)

* changed error message format (moved newlines for better readability.)

* Fixed linting issue.

* Fixed unit test (it checks the first element of the new pair)
cptspacemanspiff authored Jan 28, 2025
1 parent 9bd18f6 commit 3be1c5e
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion exir/tests/test_arg_validator.py
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ def forward(self, x):
ops.edge.aten._log_softmax.default.name(),
)
self.assertDictEqual(
validator.violating_ops[key],
validator.violating_ops[key][0],
{
"self": torch.bfloat16,
"__ret_0": torch.bfloat16,
8 changes: 4 additions & 4 deletions exir/verification/arg_validator.py
Original file line number Diff line number Diff line change
@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):

def __init__(self, graph_module: torch.fx.GraphModule) -> None:
super().__init__(graph_module)
self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
defaultdict(dict)
)
self.violating_ops: Dict[
EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node]
] = defaultdict(dict)

def run_node(self, n: torch.fx.Node) -> None:
self.node = n
@@ -125,5 +125,5 @@ def call_function( # noqa: C901 # pyre-fixme[14]

valid = target._schema.dtype_constraint.validate(tensor_arg_types)
if not valid:
self.violating_ops[target] = tensor_arg_types
self.violating_ops[target] = (tensor_arg_types, self.node)
return super().call_function(target, args, kwargs) # pyre-fixme[6]
10 changes: 8 additions & 2 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
@@ -189,9 +189,15 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
return

if validator.violating_ops:
error_msg = ""
for op, node in validator.violating_ops.items():
# error_msg += f"#####################################################\n"
error_msg += f"\nOperator: {op} with args: {node[0]}\n"
error_msg += f"stack trace: {node[1].stack_trace}\n"
# error_msg += f"#####################################################\n"
raise SpecViolationError(
f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding "
f"These operators are taking Tensor inputs with mismatched dtypes:\n{error_msg}"
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs."
)


0 comments on commit 3be1c5e

Please sign in to comment.