mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
PyO3: Add equal
and __richcmp__
to candle.Tensor
(#1099)
* add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs`
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
from candle.utils import cuda_is_available
|
||||
from candle.testing import assert_equal
|
||||
import pytest
|
||||
|
||||
|
||||
@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
|
||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||
|
||||
|
||||
def assert_bool(t: Tensor, expected: bool):
|
||||
assert t.shape == ()
|
||||
assert str(t.dtype) == str(candle.u8)
|
||||
assert bool(t.values()) == expected
|
||||
|
||||
|
||||
def test_tensor_supports_equality_opperations_with_scalars():
|
||||
t = Tensor(42.0)
|
||||
|
||||
assert_bool(t == 42.0, True)
|
||||
assert_bool(t == 43.0, False)
|
||||
|
||||
assert_bool(t != 42.0, False)
|
||||
assert_bool(t != 43.0, True)
|
||||
|
||||
assert_bool(t > 41.0, True)
|
||||
assert_bool(t > 42.0, False)
|
||||
|
||||
assert_bool(t >= 41.0, True)
|
||||
assert_bool(t >= 42.0, True)
|
||||
|
||||
assert_bool(t < 43.0, True)
|
||||
assert_bool(t < 42.0, False)
|
||||
|
||||
assert_bool(t <= 43.0, True)
|
||||
assert_bool(t <= 42.0, True)
|
||||
|
||||
|
||||
def test_tensor_supports_equality_opperations_with_tensors():
|
||||
t = Tensor(42.0)
|
||||
same = Tensor(42.0)
|
||||
other = Tensor(43.0)
|
||||
|
||||
assert_bool(t == same, True)
|
||||
assert_bool(t == other, False)
|
||||
|
||||
assert_bool(t != same, False)
|
||||
assert_bool(t != other, True)
|
||||
|
||||
assert_bool(t > same, False)
|
||||
assert_bool(t > other, False)
|
||||
|
||||
assert_bool(t >= same, True)
|
||||
assert_bool(t >= other, False)
|
||||
|
||||
assert_bool(t < same, False)
|
||||
assert_bool(t < other, True)
|
||||
|
||||
assert_bool(t <= same, True)
|
||||
assert_bool(t <= other, True)
|
||||
|
||||
|
||||
def test_tensor_equality_opperations_can_broadcast():
|
||||
# Create a decoder attention mask as a test case
|
||||
# e.g.
|
||||
# [[1,0,0]
|
||||
# [1,1,0]
|
||||
# [1,1,1]]
|
||||
mask_cond = candle.Tensor([0, 1, 2])
|
||||
mask = mask_cond < (mask_cond + 1).reshape((3, 1))
|
||||
assert mask.shape == (3, 3)
|
||||
assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
|
||||
|
||||
|
||||
def test_tensor_can_be_hashed():
|
||||
t = Tensor(42.0)
|
||||
other = Tensor(42.0)
|
||||
# Hash should represent a unique tensor
|
||||
assert hash(t) != hash(other)
|
||||
assert hash(t) == hash(t)
|
||||
|
||||
|
||||
def test_tensor_can_be_expanded_with_none():
|
||||
t = candle.rand((12, 12))
|
||||
|
||||
|
Reference in New Issue
Block a user