mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
PyO3: Add equal
and __richcmp__
to candle.Tensor
(#1099)
* add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs`
This commit is contained in:
@ -124,16 +124,46 @@ class Tensor:
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
||||
"""
|
||||
Return a slice of a tensor.
|
||||
"""
|
||||
pass
|
||||
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
@ -159,6 +189,11 @@ class Tensor:
|
||||
Divide a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
def abs(self) -> Tensor:
|
||||
"""
|
||||
Performs the `abs` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
def argmax_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Returns the indices of the maximum value(s) across the selected dimension.
|
||||
@ -308,6 +343,12 @@ class Tensor:
|
||||
ranges from `start` to `start + len`.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def nelement(self) -> int:
|
||||
"""
|
||||
Gets the tensor's element count.
|
||||
"""
|
||||
pass
|
||||
def powf(self, p: float) -> Tensor:
|
||||
"""
|
||||
Performs the `pow` operation on the tensor with the given exponent.
|
||||
|
70
candle-pyo3/py_src/candle/testing/__init__.py
Normal file
70
candle-pyo3/py_src/candle/testing/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
|
||||
|
||||
_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
|
||||
|
||||
|
||||
def _assert_tensor_metadata(
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_layout: bool = True,
|
||||
check_stride: bool = False,
|
||||
):
|
||||
if check_device:
|
||||
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
|
||||
|
||||
if check_dtype:
|
||||
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
|
||||
|
||||
if check_layout:
|
||||
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
|
||||
|
||||
if check_stride:
|
||||
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
|
||||
|
||||
|
||||
def assert_equal(
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_layout: bool = True,
|
||||
check_stride: bool = False,
|
||||
):
|
||||
"""
|
||||
Asserts that two tensors are exact equals.
|
||||
"""
|
||||
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
|
||||
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
|
||||
|
||||
|
||||
def assert_almost_equal(
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
rtol=1e-05,
|
||||
atol=1e-08,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_layout: bool = True,
|
||||
check_stride: bool = False,
|
||||
):
|
||||
"""
|
||||
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
|
||||
|
||||
Computes: |actual - expected| ≤ atol + rtol x |expected|
|
||||
"""
|
||||
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
|
||||
|
||||
# Secure against overflow of u32 and u8 tensors
|
||||
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
|
||||
actual = actual.to(candle.i64)
|
||||
expected = expected.to(candle.i64)
|
||||
|
||||
diff = (actual - expected).abs()
|
||||
|
||||
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
|
||||
|
||||
assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"
|
Reference in New Issue
Block a user