refine more code
HydrogenSulfate committed Nov 26, 2024
1 parent bb40851 commit a7163f9
Showing 3 changed files with 4 additions and 6 deletions.
3 changes: 1 addition & 2 deletions array_api_compat/paddle/_info.py
@@ -170,8 +170,7 @@ def _dtypes(self, kind):
         int32 = paddle.int32
         int64 = paddle.int64
         uint8 = paddle.uint8
-        # uint16, uint32, and uint64 are present in newer versions of pytorch,
-        # but they aren't generally supported by the array API functions, so
+        # uint16, uint32, and uint64 are not fully supported in paddle,
         # we omit them from this function.
         float32 = paddle.float32
         float64 = paddle.float64
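
For orientation, the dtype table that _dtypes presumably assembles from these names can be sketched as follows. This is a minimal, hypothetical sketch: the dtype names are the ones visible in the diff, while the kind strings and the helper name dtypes_sketch are illustrative and not the module's actual API.

import paddle

# Only the dtypes named in the diff; uint16/uint32/uint64 are left out
# because paddle does not fully support them.
_ALL_DTYPES = {
    "int32": paddle.int32,
    "int64": paddle.int64,
    "uint8": paddle.uint8,
    "float32": paddle.float32,
    "float64": paddle.float64,
}

def dtypes_sketch(kind=None):
    # kind=None returns the full table; a string selects one array-API kind.
    if kind is None:
        return dict(_ALL_DTYPES)
    kinds = {
        "signed integer": ("int32", "int64"),
        "unsigned integer": ("uint8",),
        "real floating": ("float32", "float64"),
    }
    return {name: _ALL_DTYPES[name] for name in kinds[kind]}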
5 changes: 2 additions & 3 deletions array_api_compat/paddle/linalg.py
@@ -28,11 +28,10 @@
 from ._aliases import matmul, matrix_transpose, tensordot

 # Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-
+# first axis with size 3)

 # paddle.cross also does not support broadcasting when it would add new
-# dimensions https://github.com/pytorch/pytorch/issues/39656
+# dimensions
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
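
To see why the wrapper pins the default to axis=-1, here is a small hypothetical usage sketch, assuming the package layout shown in the file paths above (array_api_compat/paddle/linalg.py). Per the comment in the diff, calling paddle.cross without an axis operates on the first axis of length 3, while the wrapper follows the array API and uses the last axis by default.

import paddle
import array_api_compat.paddle.linalg as xp_linalg  # import path assumed from the diff

a = paddle.rand([3, 4, 3])
b = paddle.rand([3, 4, 3])

# Wrapper: cross product along the last axis (size 3), the array API default.
wrapped = xp_linalg.cross(a, b)

# Native paddle: no axis given, so the first axis with length 3 (axis 0) is used.
native = paddle.cross(a, b)

# Same shapes, but computed along different axes, so the values generally differ.
print(wrapped.shape, native.shape)  # [3, 4, 3] [3, 4, 3]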
2 changes: 1 addition & 1 deletion tests/test_array_namespace.py
@@ -2,7 +2,7 @@
 import sys
 import warnings

-# import jax
+import jax
 import numpy as np
 import pytest
 import torch
