From f71f826f9c1d7e443a3fdad1ec5b0707bd45935a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Aug 2024 15:39:23 -0600 Subject: [PATCH 1/2] Use conj_physical for torch.conj torch.conj sets the conjugation bit, which breaks other libraries. Fixes #173. --- array_api_compat/torch/_aliases.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index c2be21fe..d0f4a5c9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -145,6 +145,9 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # Basic renames bitwise_invert = torch.bitwise_not newaxis = None +# torch.conj sets the conjugation bit, which breaks conversion to other +# libraries. See https://github.com/data-apis/array-api-compat/issues/173 +conj = torch.conj_physical # Two-arg elementwise functions # These require a wrapper to do the correct type promotion on 0-D tensors From b754016fb265b7d498562213f9915fa5f3a158d5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Aug 2024 12:50:17 -0600 Subject: [PATCH 2/2] Fix torch __all__ --- array_api_compat/torch/_aliases.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index d0f4a5c9..a6d642e2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -707,18 +707,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', - 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', - 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', - 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', - 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', - 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', - 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', + 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', + 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', + 'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp', + 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', + 'min', 'clip', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', + 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', + 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', + 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', + 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', + 'UniqueInverseResult', 'unique_all', 'unique_counts', + 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', + 'vecdot', 'tensordot', 'isdtype', 'take'] _all_ignore = ['torch', 'get_xp']