Files
candle/candle-pyo3/tests/bindings/test_testing.py
Lukas Kreussel c05c0a8213 PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
2023-10-30 15:17:28 +00:00

34 lines
1.1 KiB
Python

import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest
@pytest.mark.parametrize(
    "dtype",
    [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64],
)
def test_assert_equal_asserts_correctly(dtype: candle.DType):
    """Check that ``assert_equal`` passes on matching tensors and fails otherwise.

    Two tensors are built from the same values and converted to ``dtype``;
    comparing them must succeed, while comparing against ``rhs + 1`` must
    raise ``AssertionError``.
    """
    values = [1, 2, 3]
    lhs = Tensor(values).to(dtype)
    rhs = Tensor(values).to(dtype)

    # Same values, same dtype: must not raise.
    assert_equal(lhs, rhs)

    # Offsetting one side by 1 must be reported as a mismatch.
    with pytest.raises(AssertionError):
        assert_equal(lhs, rhs + 1)
@pytest.mark.parametrize(
    "dtype",
    [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64],
)
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
    """Exercise ``assert_almost_equal`` with default and explicit tolerances.

    Two tensors are built from the same values and converted to ``dtype``.
    The test then compares one against the other offset by 1, first with no
    tolerance arguments, then with tolerances that should accept the offset
    (``atol=20``, ``rtol=20``) and tolerances that should still reject it
    (``atol=0.9``, ``rtol=0.1``).
    """
    values = [1, 2, 3]
    lhs = Tensor(values).to(dtype)
    rhs = Tensor(values).to(dtype)

    # Matching tensors pass without any explicit tolerance.
    assert_almost_equal(lhs, rhs)

    # With no tolerance arguments, an offset of 1 must be rejected.
    with pytest.raises(AssertionError):
        assert_almost_equal(lhs, rhs + 1)

    # A tolerance of 20, absolute or relative, must accept the offset.
    for loose in ({"atol": 20}, {"rtol": 20}):
        assert_almost_equal(lhs, rhs + 1, **loose)

    # These tighter tolerances must still reject the offset.
    for tight in ({"atol": 0.9}, {"rtol": 0.1}):
        with pytest.raises(AssertionError):
            assert_almost_equal(lhs, rhs + 1, **tight)