PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)

* add `equal` to tensor

* add `__richcmp__` support  for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
This commit is contained in:
Lukas Kreussel
2023-10-30 16:17:28 +01:00
committed by GitHub
parent 969960847a
commit c05c0a8213
7 changed files with 325 additions and 3 deletions

View File

@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
@property
def nelement(self) -> int:
"""
Gets the tensor's element count.
"""
pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.

View File

@ -0,0 +1,70 @@
import candle
from candle import Tensor
_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
if check_device:
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
if check_dtype:
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
if check_layout:
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
if check_stride:
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
def assert_equal(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are exact equals.
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
def assert_almost_equal(
actual: Tensor,
expected: Tensor,
rtol=1e-05,
atol=1e-08,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
Computes: |actual - expected| ≤ atol + rtol x |expected|
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
# Secure against overflow of u32 and u8 tensors
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
actual = actual.to(candle.i64)
expected = expected.to(candle.i64)
diff = (actual - expected).abs()
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"