From e978e4db5ff9b025fde6e59e91a66c9f8bbdacb5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 30 Aug 2024 22:52:54 +0000 Subject: [PATCH 1/8] Introduce OpSignature and deprecate ParamSchema --- onnxscript/_internal/_schemas.py | 553 ++++++++++++++++++++++++++ onnxscript/_internal/_schemas_test.py | 180 +++++++++ onnxscript/values.py | 90 ++++- 3 files changed, 808 insertions(+), 15 deletions(-) create mode 100644 onnxscript/_internal/_schemas.py create mode 100644 onnxscript/_internal/_schemas_test.py diff --git a/onnxscript/_internal/_schemas.py b/onnxscript/_internal/_schemas.py new file mode 100644 index 000000000..551ea60cc --- /dev/null +++ b/onnxscript/_internal/_schemas.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +import collections.abc +import dataclasses +import inspect +import logging +import types +import typing +from typing import ( + Any, + Iterator, + Mapping, + Optional, + Sequence, + TypeVar, + Union, +) + +import onnx +import onnxscript +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +# A special value to indicate that the default value is not specified +class _Empty: + def __repr__(self): + return "_EMPTY_DEFAULT" + + +_EMPTY_DEFAULT = _Empty() + +# Map from python type to corresponding ONNX AttributeProto type +_PY_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, + ir.Tensor: ir.AttributeType.TENSOR, + ir.TensorProtocol: ir.AttributeType.TENSOR, + ir.Graph: ir.AttributeType.GRAPH, + ir.GraphProtocol: ir.AttributeType.GRAPH, +} + +# Map from python type to corresponding ONNX AttributeProto type, +# for repeated (i.e., list of) values +_LIST_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, + ir.Tensor: ir.AttributeType.TENSORS, + ir.TensorProtocol: ir.AttributeType.TENSORS, + ir.Graph: ir.AttributeType.GRAPHS, + 
ir.GraphProtocol: ir.AttributeType.GRAPHS, +} + +_ALL_VALUE_TYPES = ( + {ir.TensorType(dtype) for dtype in ir.DataType} + | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} + | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} +) + +# TypeAnnotationValue represents the (value of) valid type-annotations recognized +# by ONNX Script. Currently, it supports +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached +TypeAnnotationValue = Any + + +@dataclasses.dataclass(frozen=True) +class TypeConstraintParam: + """Type constraint for a parameter. + + Attributes: + name: Name of the parameter. E.g. "TFloat" + allowed_types: Allowed types for the parameter. + """ + + name: str + allowed_types: set[ir.TypeProtocol] + description: str = "" + + def __hash__(self) -> int: + return hash((self.name, tuple(self.allowed_types))) + + def __str__(self) -> str: + allowed_types_str = " | ".join(str(t) for t in self.allowed_types) + return f"{self.name}={allowed_types_str}" + + @classmethod + def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) + + @classmethod + def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, _ALL_VALUE_TYPES, description) # type: ignore + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """A formal parameter of an operator.""" + + name: str + type_constraint: TypeConstraintParam + required: bool + variadic: bool + default: Any = _EMPTY_DEFAULT + # TODO: Add other properties too + + def __str__(self) -> str: + type_str = self.type_constraint.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def 
has_default(self) -> bool: + return self.default is not _EMPTY_DEFAULT + + +@dataclasses.dataclass(frozen=True) +class AttributeParameter: + name: str + type: ir.AttributeType + required: bool + default: ir.Attr | None = None + + def __str__(self) -> str: + type_str = self.type.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not None + + +def _get_type_from_str( + type_str: str, +) -> ir.TensorType | ir.SequenceType | ir.OptionalType: + """Converter a type_str from ONNX Opschema to ir.TypeProtocol. + + A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". + """ + + # TODO: Upstream this to IR + + # Split the type_str a sequence types and dtypes + # 1. Remove the ending ")" + striped = type_str.rstrip(")") + # 2. Split the type_str by "(" + type_parts = striped.split("(") + + # Convert the dtype to ir.DataType + dtype = ir.DataType[type_parts[-1].upper()] + + # Create a place holder type first + type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) + + # Construct the type + for type_part in reversed(type_parts[:-1]): + if type_part == "tensor": + type_ = ir.TensorType(dtype) + elif type_part == "seq": + type_ = ir.SequenceType(type_) + elif type_part == "optional": + type_ = ir.OptionalType(type_) + else: + raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") + return type_ # type: ignore[return-value] + + +def _convert_formal_parameter( + param: onnx.defs.OpSchema.FormalParameter, + type_constraints: Mapping[str, TypeConstraintParam], +) -> Parameter: + """Convert a formal parameter from ONNX Opschema to Parameter.""" + if param.type_str in type_constraints: + type_constraint = type_constraints[param.type_str] + else: + # param.type_str can be a plain type like 'int64'. 
+ type_constraint = TypeConstraintParam( + name=param.name, + allowed_types={_get_type_from_str(param.type_str)}, # type: ignore + ) + return Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, + variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + ) + + +def _is_optional(type_: type) -> bool: + """Returns whether a type_ is an Optional.""" + origin_type = typing.get_origin(type_) + if origin_type is Union and type(None) in typing.get_args(type_): + # Python < 3.10 + return True + if origin_type is Optional: + # Python >= 3.10 + return True + if ( + hasattr(types, "UnionType") + and origin_type is types.UnionType + and type(None) in typing.get_args(type_) + ): + # Python >= 3.10 + return True + return False + + +def _get_attr_type(type_: type) -> ir.AttributeType: + """Obtain the type of the attribute from a Python class.""" + try: + if type_ in _PY_TYPE_TO_ATTR_TYPE: + return _PY_TYPE_TO_ATTR_TYPE[type_] + origin_type = typing.get_origin(type_) + if origin_type is None: + return ir.AttributeType.UNDEFINED + if origin_type in ( + collections.abc.Sequence, + Sequence, + typing.List, + list, + typing.Tuple, + tuple, + ): + inner_type = typing.get_args(type_)[0] + if inner_type in _LIST_TYPE_TO_ATTR_TYPE: + return _LIST_TYPE_TO_ATTR_TYPE[inner_type] + except TypeError: + logger.warning("TypeError when checking %s.", type_, exc_info=True) + return ir.AttributeType.UNDEFINED + + +def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: + """Returns the name of the type constraint for a given type annotation. + + Args: + type_: A Python type. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. 
+ """ + if isinstance(type_, TypeVar): + return type_.__name__ + if _is_optional(type_): + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = _get_type_constraint_name(subtype) + return type_param_name if type_param_name else None + origin_type = typing.get_origin(type_) + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + type_param_name = _get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None + + +def _get_allowed_types_from_type_annotation( + type_: TypeAnnotationValue, +) -> set[ir.TypeProtocol]: + """Obtain the allowed types from a type annotation.""" + if type_ is onnxscript.onnx_types.TensorType: + # Any tensor type + return {ir.TensorType(dtype) for dtype in ir.DataType} + + allowed_types: set[ir.TypeProtocol] + + if isinstance(type_, TypeVar): + allowed_types = set() + if constraints := type_.__constraints__: + for constraint in constraints: + allowed_types.update( + _get_allowed_types_from_type_annotation(constraint) + ) + else: + bound = type_.__bound__ + if bound is None: + allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] + else: + allowed_types.update(_get_allowed_types_from_type_annotation(bound)) + return allowed_types + if hasattr(type_, "dtype"): + # A single tensor type like INT64, FLOAT, etc. + return {ir.TensorType(ir.DataType(type_.dtype))} + if _is_optional(type_): + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. 
+ return allowed_types + + origin_type = typing.get_origin(type_) + if origin_type is Union: + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + assert ( + subtype is not type(None) + ), "Union should not contain None type because it is handled by _is_optional." + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + return allowed_types + + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + return { + ir.SequenceType(t) + for t in _get_allowed_types_from_type_annotation(subtypes[0]) + } + + # Allow everything by default + return _ALL_VALUE_TYPES # type: ignore[return-value] + + +@dataclasses.dataclass +class OpSignature: + """Schema for an operator. + + Attributes: + domain: Domain of the operator. E.g. "". + name: Name of the operator. E.g. "Add". + overload: Overload name of the operator. + params: Input parameters. When the op is an ONNX function definition, + the order is according to the function signature. This mean we can + interleave ONNX inputs and ONNX attributes in the list. + outputs: Output parameters. 
+ """ + + domain: str + name: str + overload: str + params: Sequence[Parameter | AttributeParameter] + outputs: Sequence[Parameter] + params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( + init=False, repr=False + ) + + def __post_init__(self): + self.params_map = {param.name: param for param in self.params} + + def get(self, name: str) -> Parameter | AttributeParameter: + return self.params_map[name] + + def __contains__(self, name: str) -> bool: + return name in self.params_map + + def __iter__(self) -> Iterator[Parameter | AttributeParameter]: + return iter(self.params) + + def __str__(self) -> str: + domain = self.domain or "''" + # TODO: Double check the separator for overload + overload = f"::{self.overload}" if self.overload else "" + params = ", ".join(str(param) for param in self.params) + outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) + type_constraints = {} + for param in self.params: + if isinstance(param, Parameter): + type_constraints[param.type_constraint.name] = param.type_constraint + for param in self.outputs: + type_constraints[param.type_constraint.name] = param.type_constraint + type_constraints_str = ", ".join( + str(type_constraint) for type_constraint in type_constraints.values() + ) + return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + + @classmethod + def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: + """Produce an OpSignature from an ONNX Opschema.""" + type_constraints = { + constraint.type_param_str: TypeConstraintParam( + name=constraint.type_param_str, + allowed_types={ + _get_type_from_str(type_str) + for type_str in constraint.allowed_type_strs + }, + description=constraint.description, + ) + for constraint in op_schema.type_constraints + } + + params = [ + _convert_formal_parameter(param, type_constraints) + for param in op_schema.inputs + ] + + for param in op_schema.attributes.values(): + default_attr = ( + 
ir.serde.deserialize_attribute(param.default_value) + if param.default_value is not None + else None + ) + if default_attr is not None: + # Set the name of the default attribute because it may have a different name from the parameter + default_attr.name = param.name + params.append( + AttributeParameter( + name=param.name, + type=ir.AttributeType(param.type), # type: ignore[arg-type] + required=param.required, + default=default_attr, # type: ignore[arg-type] + ) + ) + + outputs = [ + _convert_formal_parameter(param, type_constraints) + for param in op_schema.outputs + ] + + return cls( + domain=op_schema.domain, + name=op_schema.name, + overload="", + params=params, + outputs=outputs, + ) + + @classmethod + def from_function( + cls, func, domain: str, name: str | None = None, overload: str = "" + ) -> OpSignature: + """Produce an OpSignature from a function using type annotation.""" + + py_signature = inspect.signature(func) + # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases + # https://github.com/python/cpython/issues/102405 + type_hints = typing.get_type_hints(func) + + params = [] + # Create a mapping from type to a unique name + type_constraints: dict[str, TypeConstraintParam] = {} + + for param in py_signature.parameters.values(): + if param.name not in type_hints: + logger.warning( + "Missing annotation for parameter '%s' from %s. 
Treating as an Input.", + param.name, + py_signature, + ) + type_constraints[param.name] = TypeConstraintParam.any_value( + f"T_{param.name}" + ) + else: + type_ = type_hints[param.name] + if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + # Construct the default attribute + if param.default is not inspect.Parameter.empty: + # TODO: Use ir_convenience instead to handle int as float + default = ir.Attr(param.name, attr_type, param.default) + else: + default = None + params.append( + AttributeParameter( + name=param.name, + type=attr_type, + required=param.default is inspect.Parameter.empty, + default=default, + ) + ) + else: + # Obtain the type constraint from the type annotation + + # 1. Get a type constraint name from the type annotation + # If the type annotation is a TypeVar or Optional[TypeVar], get its name + # Otherwise, name it T_{param.name} + type_constraint_name = _get_type_constraint_name(type_) + if type_constraint_name is None: + type_constraint_name = f"T_{param.name}" + + # 2. If the type constraint param is already initialized, use it + if type_constraint_name in type_constraints: + type_constraint = type_constraints[type_constraint_name] + else: + # 3. Otherwise, create a new TypeConstraintParam + type_constraint = TypeConstraintParam( + name=type_constraint_name, + allowed_types=_get_allowed_types_from_type_annotation( + type_ + ), + ) + type_constraints[type_constraint_name] = type_constraint + # 4. 
Create Parameter + params.append( + Parameter( # type: ignore[arg-type] + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + + return_type = type_hints.get("return") + + outputs = [] + if return_type is None: + # No returns + pass + else: + if typing.get_origin(return_type) is tuple: + # Multiple returns + return_types = typing.get_args(return_type) + else: + return_types = [return_type] # type: ignore[assignment] + + for i, return_type_i in enumerate(return_types): + if ( + return_param_name := _get_type_constraint_name(return_type_i) + ) in type_constraints: + type_constraint = type_constraints[return_param_name] + else: + return_param_name = f"TReturn{i}" + type_constraint = TypeConstraintParam( + name=return_param_name, + allowed_types=_get_allowed_types_from_type_annotation( + return_type_i + ), + ) + type_constraints[return_param_name] = type_constraint + outputs.append( + Parameter( + name=return_param_name, + type_constraint=type_constraint, + required=True, + variadic=False, + default=_EMPTY_DEFAULT, + ) + ) + + return cls( + domain=domain, + name=name or func.__name__, + overload=overload, + params=params, + outputs=outputs, + ) diff --git a/onnxscript/_internal/_schemas_test.py b/onnxscript/_internal/_schemas_test.py new file mode 100644 index 000000000..605407c7b --- /dev/null +++ b/onnxscript/_internal/_schemas_test.py @@ -0,0 +1,180 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import unittest
+from typing import Any, Optional, Sequence, TypeVar, Union
+
+import onnxscript
+import onnxscript.testing
+import parameterized
+from onnxscript import FLOAT, INT64, ir
+
+from onnxscript._internal import _schemas
+
+_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT)
+_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64)
+_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT])
+
+
+class TypeConversionFunctionsTest(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            (
+                "tensor_type_all",
+                onnxscript.onnx_types.TensorType,
+                {ir.TensorType(dtype) for dtype in ir.DataType},
+            ),
+            ("tensor_type", INT64, {ir.TensorType(ir.DataType.INT64)}),
+            (
+                "tensor_type_union",
+                Union[INT64, FLOAT],
+                {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+            ),
+            (
+                "tensor_type_variadic_shape",
+                INT64[...],
+                {ir.TensorType(ir.DataType.INT64)},
+            ),
+            ("tensor_type_shape", INT64[10], {ir.TensorType(ir.DataType.INT64)}),
+            (
+                "type_var_constraints",
+                _TestTypeVarConstraints,
+                {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+            ),
+            (
+                "type_bound_one",
+                _TestTypeVarOneBound,
+                {ir.TensorType(ir.DataType.INT64)},
+            ),
+            (
+                "type_bound_two",
+                _TestTypeVarTwoBound,
+                {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+            ),
+            (
+                "optional_tensor_type_all",
+                Optional[onnxscript.onnx_types.TensorType],
+                {ir.TensorType(dtype) for dtype in ir.DataType},
+            ),
+            (
+                "optional_tensor_type",
+                Optional[INT64],
+                {ir.TensorType(ir.DataType.INT64)},
+            ),
+            (
+                "optional_tensor_type_union",
+                Optional[Union[INT64, FLOAT]],
+                {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+            ),
+            (
+                "optional_tensor_type_variadic_shape",
+                Optional[INT64[...]],
+                {ir.TensorType(ir.DataType.INT64)},
+            ),
+            (
+                
"optional_tensor_type_shape", + Optional[INT64[10]], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_type_bound_one", + Optional[_TestTypeVarOneBound], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "sequence_type_all", + Sequence[onnxscript.onnx_types.TensorType], + {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}, + ), + ( + "sequence_type", + Sequence[INT64], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "union_sequence_type", + Union[Sequence[INT64], Sequence[FLOAT]], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_variadic_shape", + Sequence[INT64[...]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_shape", + Sequence[INT64[10]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_var_constraints", + Sequence[_TestTypeVarConstraints], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_bound_one", + Sequence[_TestTypeVarOneBound], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_bound_two", + Sequence[_TestTypeVarTwoBound], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ] + ) + def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): + self.assertEqual( + _schemas._get_allowed_types_from_type_annotation(pytype), expected + ) + + @parameterized.parameterized.expand( + [ + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), + 
("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), + ( + "optional_type_var", + Optional[_TestTypeVarOneBound], + "_TestTypeVarOneBound", + ), + ( + "sequence_type_var", + Sequence[_TestTypeVarOneBound], + "Sequence__TestTypeVarOneBound", + ), + ("normal_type", INT64, None), + ("union_type", Union[INT64, FLOAT], None), + ("optional_type", Optional[INT64], None), + ("sequence_type", Sequence[INT64], None), + ("optional_sequence_type", Optional[Sequence[INT64]], None), + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), + ] + ) + def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): + self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/values.py b/onnxscript/values.py index f47c64f70..ca54313b4 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -24,7 +24,7 @@ from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation -from onnxscript._internal import ast_utils, deprecation +from onnxscript._internal import _schemas, ast_utils, deprecation _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -173,7 +173,7 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: return onnx.helper.get_attribute_value(attr_proto) -def param_schemas_from_op_schema( +def _param_schemas_from_op_schema( op_schema: onnx.defs.OpSchema, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from an ONNX OpSchema.""" @@ -222,7 +222,7 @@ def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): ) -def param_schemas_from_function_ir( +def _param_schemas_from_function_ir( function_ir: irbuilder.IRFunction, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from a FunctionIR.""" @@ -259,7 +259,7 @@ def opset(self) -> Opset: ... @property def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... 
- def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: ... + def signature(self) -> Optional[_schemas.OpSignature]: ... class Op(OpLike): @@ -274,18 +274,19 @@ class Op(OpLike): """ def __init__( - self, opset: Opset, opname: str, op_schema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, name: str, op_schema: Optional[onnx.defs.OpSchema] = None ) -> None: self._opset = opset - self._name = opname - self._op_schema = op_schema or opset[opname] + self._name = name + self._op_schema = op_schema or opset[name] + self._signature: Optional[_schemas.OpSignature] = None self._param_schemas: Optional[tuple[ParamSchema, ...]] = None if self._op_schema is None: logger.debug( "An OpSchema was not provided for Op '%s' and " "there is not one found in opset '%s'.", - opname, + name, opset, ) @@ -312,10 +313,32 @@ def opset(self) -> Opset: def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="check if '.op_schema' is not None instead", + ) def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" return self.op_schema is not None + @property + def signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_op_schema(self.op_schema) + return self._signature + + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.signature' instead", + ) def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: """Returns the parameter schemas for this op, if it has one.""" if self._param_schemas is not None: @@ -325,7 +348,7 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: if op_schema is None: return None - self._param_schemas = param_schemas_from_op_schema(op_schema) + self._param_schemas = 
_param_schemas_from_op_schema(op_schema) return self._param_schemas @@ -362,7 +385,7 @@ def as_tuple(self) -> tuple[str, list[str], str]: return (self.name, self.allowed_types, self.description) -def op_schema_from_function_ir( +def _op_schema_from_function_ir( function_ir: irbuilder.IRFunction, opset: Opset ) -> onnx.defs.OpSchema: """Construct an ONNX OpSchema from an IRFunction.""" @@ -486,7 +509,7 @@ def __init__( @property @deprecation.deprecated( since="0.1", - removed_in="0.3", + removed_in="the future", instructions="use '.name' instead", ) def opname(self) -> str: @@ -500,10 +523,23 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: if self._op_schema is not None: return self._op_schema - self._op_schema = op_schema_from_function_ir(self.function_ir, self.opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) return self._op_schema + def signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) + return self._signature + def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. 
@@ -531,6 +567,11 @@ def __call__(self, *args, **kwargs): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.function!r})" + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.signature' instead", + ) def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" if self._param_schemas is not None: @@ -539,7 +580,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas def to_function_proto(self) -> onnx.FunctionProto: @@ -612,10 +653,29 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema # FIXME(justinchuby): outputs are empty. Need to fix. 
- self._op_schema = op_schema_from_function_ir(self.function_ir, self._opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) return self._op_schema + @property + def signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) + return self._signature + + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.signature' instead", + ) def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" if self._param_schemas is not None: @@ -624,7 +684,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas From ea90e0e50a1f6d133343787f10137f8966a5bc59 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 30 Aug 2024 22:54:07 +0000 Subject: [PATCH 2/8] Format --- onnxscript/_internal/_schemas.py | 35 ++++++++++----------------- onnxscript/_internal/_schemas_test.py | 14 ++++------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/onnxscript/_internal/_schemas.py b/onnxscript/_internal/_schemas.py index 551ea60cc..5f9fc4ce8 100644 --- a/onnxscript/_internal/_schemas.py +++ b/onnxscript/_internal/_schemas.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
from __future__ import annotations import collections.abc @@ -17,6 +19,7 @@ ) import onnx + import onnxscript from onnxscript import ir @@ -281,9 +284,7 @@ def _get_allowed_types_from_type_annotation( allowed_types = set() if constraints := type_.__constraints__: for constraint in constraints: - allowed_types.update( - _get_allowed_types_from_type_annotation(constraint) - ) + allowed_types.update(_get_allowed_types_from_type_annotation(constraint)) else: bound = type_.__bound__ if bound is None: @@ -309,8 +310,8 @@ def _get_allowed_types_from_type_annotation( allowed_types = set() subtypes = typing.get_args(type_) for subtype in subtypes: - assert ( - subtype is not type(None) + assert subtype is not type( + None ), "Union should not contain None type because it is handled by _is_optional." allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) return allowed_types @@ -318,8 +319,7 @@ def _get_allowed_types_from_type_annotation( if isinstance(origin_type, type) and issubclass(origin_type, Sequence): subtypes = typing.get_args(type_) return { - ir.SequenceType(t) - for t in _get_allowed_types_from_type_annotation(subtypes[0]) + ir.SequenceType(t) for t in _get_allowed_types_from_type_annotation(subtypes[0]) } # Allow everything by default @@ -385,8 +385,7 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: constraint.type_param_str: TypeConstraintParam( name=constraint.type_param_str, allowed_types={ - _get_type_from_str(type_str) - for type_str in constraint.allowed_type_strs + _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs }, description=constraint.description, ) @@ -394,8 +393,7 @@ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: } params = [ - _convert_formal_parameter(param, type_constraints) - for param in op_schema.inputs + _convert_formal_parameter(param, type_constraints) for param in op_schema.inputs ] for param in op_schema.attributes.values(): @@ -417,8 +415,7 @@ def 
from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: ) outputs = [ - _convert_formal_parameter(param, type_constraints) - for param in op_schema.outputs + _convert_formal_parameter(param, type_constraints) for param in op_schema.outputs ] return cls( @@ -451,9 +448,7 @@ def from_function( param.name, py_signature, ) - type_constraints[param.name] = TypeConstraintParam.any_value( - f"T_{param.name}" - ) + type_constraints[param.name] = TypeConstraintParam.any_value(f"T_{param.name}") else: type_ = type_hints[param.name] if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: @@ -488,9 +483,7 @@ def from_function( # 3. Otherwise, create a new TypeConstraintParam type_constraint = TypeConstraintParam( name=type_constraint_name, - allowed_types=_get_allowed_types_from_type_annotation( - type_ - ), + allowed_types=_get_allowed_types_from_type_annotation(type_), ) type_constraints[type_constraint_name] = type_constraint # 4. Create Parameter @@ -529,9 +522,7 @@ def from_function( return_param_name = f"TReturn{i}" type_constraint = TypeConstraintParam( name=return_param_name, - allowed_types=_get_allowed_types_from_type_annotation( - return_type_i - ), + allowed_types=_get_allowed_types_from_type_annotation(return_type_i), ) type_constraints[return_param_name] = type_constraint outputs.append( diff --git a/onnxscript/_internal/_schemas_test.py b/onnxscript/_internal/_schemas_test.py index 605407c7b..28979c03e 100644 --- a/onnxscript/_internal/_schemas_test.py +++ b/onnxscript/_internal/_schemas_test.py @@ -1,18 +1,16 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-# -------------------------------------------------------------------------- from __future__ import annotations import unittest from typing import Any, Optional, Sequence, TypeVar, Union +import parameterized + import onnxscript import onnxscript.testing -import parameterized from onnxscript import FLOAT, INT64, ir - -from torch_onnx import _schemas +from onnxscript._internal import _schemas _TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) _TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) @@ -146,9 +144,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): ] ) def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): - self.assertEqual( - _schemas._get_allowed_types_from_type_annotation(pytype), expected - ) + self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) @parameterized.parameterized.expand( [ From 56c1bba175da55fa8c5041c8d87271cc2ba242b7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 30 Aug 2024 22:57:37 +0000 Subject: [PATCH 3/8] Move --- onnxscript/{_internal => ir}/_schemas.py | 0 onnxscript/{_internal => ir}/_schemas_test.py | 2 +- onnxscript/values.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) rename onnxscript/{_internal => ir}/_schemas.py (100%) rename onnxscript/{_internal => ir}/_schemas_test.py (99%) diff --git a/onnxscript/_internal/_schemas.py b/onnxscript/ir/_schemas.py similarity index 100% rename from onnxscript/_internal/_schemas.py rename to onnxscript/ir/_schemas.py diff --git a/onnxscript/_internal/_schemas_test.py b/onnxscript/ir/_schemas_test.py similarity index 99% rename from onnxscript/_internal/_schemas_test.py rename to onnxscript/ir/_schemas_test.py index 28979c03e..58aa62c8f 100644 --- a/onnxscript/_internal/_schemas_test.py +++ b/onnxscript/ir/_schemas_test.py @@ -10,7 +10,7 @@ import onnxscript import onnxscript.testing from onnxscript import FLOAT, INT64, ir -from onnxscript._internal import _schemas 
+from onnxscript.ir import _schemas _TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) _TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) diff --git a/onnxscript/values.py b/onnxscript/values.py index ca54313b4..1f46075a1 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -24,7 +24,8 @@ from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation -from onnxscript._internal import _schemas, ast_utils, deprecation +from onnxscript._internal import ast_utils, deprecation +from onnxscript.ir import _schemas _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, From 0ca760427f51625d3442a6f6f2231a25f7b68c3b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 30 Aug 2024 15:58:50 -0700 Subject: [PATCH 4/8] Update onnxscript/ir/_schemas.py --- onnxscript/ir/_schemas.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 5f9fc4ce8..335cc77ff 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -151,9 +151,6 @@ def _get_type_from_str( A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". """ - - # TODO: Upstream this to IR - # Split the type_str a sequence types and dtypes # 1. Remove the ending ")" striped = type_str.rstrip(")") From ff1aaa5e5b42aaae0912cec0befc3f0ca050f0ca Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 31 Aug 2024 15:40:02 +0000 Subject: [PATCH 5/8] setter --- onnxscript/values.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/onnxscript/values.py b/onnxscript/values.py index 1f46075a1..2fd344ff4 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -260,6 +260,7 @@ def opset(self) -> Opset: ... @property def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... + @property def signature(self) -> Optional[_schemas.OpSignature]: ... 
@@ -335,6 +336,10 @@ def signature(self) -> Optional[_schemas.OpSignature]: self._signature = _schemas.OpSignature.from_op_schema(self.op_schema) return self._signature + @signature.setter + def signature(self, value: _schemas.OpSignature): + self._signature = value + @deprecation.deprecated( since="0.1", removed_in="the future", @@ -528,6 +533,7 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema + @property def signature(self) -> Optional[_schemas.OpSignature]: """Returns the signature of this op.""" if self._signature is not None: @@ -541,6 +547,10 @@ def signature(self) -> Optional[_schemas.OpSignature]: ) return self._signature + @signature.setter + def signature(self, value: _schemas.OpSignature): + self._signature = value + def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. @@ -672,6 +682,10 @@ def signature(self) -> Optional[_schemas.OpSignature]: ) return self._signature + @signature.setter + def signature(self, value: _schemas.OpSignature): + self._signature = value + @deprecation.deprecated( since="0.1", removed_in="the future", From c10f5f5bf4931e174f55b262009fa4dac26d1a07 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Oct 2024 23:15:32 +0000 Subject: [PATCH 6/8] update with the latest pytorch change --- onnxscript/ir/_schemas.py | 41 +++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 335cc77ff..3422a0c28 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -8,15 +8,7 @@ import logging import types import typing -from typing import ( - Any, - Iterator, - Mapping, - Optional, - Sequence, - TypeVar, - Union, -) +from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union import onnx @@ -103,7 +95,7 @@ def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: @classmethod def any_value(cls, name: 
str, description: str = "") -> TypeConstraintParam:
-        return cls(name, _ALL_VALUE_TYPES, description)  # type: ignore
+        return cls(name, _ALL_VALUE_TYPES, description)  # type: ignore[arg-type]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -129,6 +121,8 @@ def has_default(self) -> bool:
 
 @dataclasses.dataclass(frozen=True)
 class AttributeParameter:
+    """A parameter in the function signature that represents an ONNX attribute."""
+
     name: str
     type: ir.AttributeType
     required: bool
@@ -147,7 +141,7 @@ def has_default(self) -> bool:
 def _get_type_from_str(
     type_str: str,
 ) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
-    """Converter a type_str from ONNX Opschema to ir.TypeProtocol.
+    """Convert a type_str from ONNX OpSchema to ir.TypeProtocol.
 
     A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
     """
@@ -180,14 +174,14 @@ def _convert_formal_parameter(
     param: onnx.defs.OpSchema.FormalParameter,
     type_constraints: Mapping[str, TypeConstraintParam],
 ) -> Parameter:
-    """Convert a formal parameter from ONNX Opschema to Parameter."""
+    """Convert a formal parameter from ONNX OpSchema to Parameter."""
     if param.type_str in type_constraints:
         type_constraint = type_constraints[param.type_str]
     else:
         # param.type_str can be a plain type like 'int64'. 
type_constraint = TypeConstraintParam( name=param.name, - allowed_types={_get_type_from_str(param.type_str)}, # type: ignore + allowed_types={_get_type_from_str(param.type_str)}, ) return Parameter( name=param.name, @@ -377,7 +371,7 @@ def __str__(self) -> str: @classmethod def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: - """Produce an OpSignature from an ONNX Opschema.""" + """Produce an OpSignature from an ONNX OpSchema.""" type_constraints = { constraint.type_param_str: TypeConstraintParam( name=constraint.type_param_str, @@ -434,7 +428,7 @@ def from_function( # https://github.com/python/cpython/issues/102405 type_hints = typing.get_type_hints(func) - params = [] + params: list[Parameter | AttributeParameter] = [] # Create a mapping from type to a unique name type_constraints: dict[str, TypeConstraintParam] = {} @@ -445,7 +439,20 @@ def from_function( param.name, py_signature, ) - type_constraints[param.name] = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) else: type_ = type_hints[param.name] if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: @@ -485,7 +492,7 @@ def from_function( type_constraints[type_constraint_name] = type_constraint # 4. 
Create Parameter params.append( - Parameter( # type: ignore[arg-type] + Parameter( name=param.name, type_constraint=type_constraint, required=param.default is inspect.Parameter.empty, From e6c36dd4e425b21c3f2915f0eb3ebd339ca25c03 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Oct 2024 23:17:39 +0000 Subject: [PATCH 7/8] lint --- onnxscript/ir/_schemas_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py index 58aa62c8f..c134bd7a6 100644 --- a/onnxscript/ir/_schemas_test.py +++ b/onnxscript/ir/_schemas_test.py @@ -144,7 +144,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): ] ) def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): - self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) + self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) # pylint: disable=protected-access @parameterized.parameterized.expand( [ @@ -169,7 +169,7 @@ def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]) ] ) def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): - self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) + self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access if __name__ == "__main__": From ce95d914431f23005c21e6dab1c8691bda3461bc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 23 Oct 2024 23:23:28 +0000 Subject: [PATCH 8/8] Warn once --- onnxscript/_internal/deprecation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/deprecation.py b/onnxscript/_internal/deprecation.py index 301565c8d..7bf18482a 100644 --- a/onnxscript/_internal/deprecation.py +++ b/onnxscript/_internal/deprecation.py @@ -12,6 +12,12 @@ T = TypeVar("T") +@functools.lru_cache(maxsize=1024) +def _warn_once(message: str): + """Issue a 
FutureWarning only once per message.""" + warnings.warn(message, category=FutureWarning, stacklevel=3) + + def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], T]: """Marks functions as deprecated. @@ -30,12 +36,10 @@ def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], def decorator(function): @functools.wraps(function) def wrapper(*args, **kwargs): - warnings.warn( + _warn_once( f"'{function.__module__}.{function.__qualname__}' " f"is deprecated in version {since} and will be " f"removed in {removed_in}. Please {instructions}.", - category=FutureWarning, - stacklevel=2, ) return function(*args, **kwargs)