mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
PyO3: Add equal
and __richcmp__
to candle.Tensor
(#1099)
* add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs`
This commit is contained in:
@ -203,7 +203,7 @@ impl Shape {
|
|||||||
|
|
||||||
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
|
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
|
||||||
/// broadcasted shape. This is to be used for binary pointwise ops.
|
/// broadcasted shape. This is to be used for binary pointwise ops.
|
||||||
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
|
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
|
||||||
let lhs = self;
|
let lhs = self;
|
||||||
let lhs_dims = lhs.dims();
|
let lhs_dims = lhs.dims();
|
||||||
let rhs_dims = rhs.dims();
|
let rhs_dims = rhs.dims();
|
||||||
|
@ -53,3 +53,39 @@ class Tensor:
|
|||||||
Return a slice of a tensor.
|
Return a slice of a tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
@ -124,16 +124,46 @@ class Tensor:
|
|||||||
Add a scalar to a tensor or two tensors together.
|
Add a scalar to a tensor or two tensors together.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
||||||
"""
|
"""
|
||||||
Return a slice of a tensor.
|
Return a slice of a tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
"""
|
"""
|
||||||
Multiply a tensor by a scalar or one tensor by another.
|
Multiply a tensor by a scalar or one tensor by another.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
|
"""
|
||||||
|
Compare a tensor with a scalar or one tensor with another.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||||
"""
|
"""
|
||||||
Add a scalar to a tensor or two tensors together.
|
Add a scalar to a tensor or two tensors together.
|
||||||
@ -159,6 +189,11 @@ class Tensor:
|
|||||||
Divide a tensor by a scalar or one tensor by another.
|
Divide a tensor by a scalar or one tensor by another.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def abs(self) -> Tensor:
|
||||||
|
"""
|
||||||
|
Performs the `abs` operation on the tensor.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def argmax_keepdim(self, dim: int) -> Tensor:
|
def argmax_keepdim(self, dim: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns the indices of the maximum value(s) across the selected dimension.
|
Returns the indices of the maximum value(s) across the selected dimension.
|
||||||
@ -308,6 +343,12 @@ class Tensor:
|
|||||||
ranges from `start` to `start + len`.
|
ranges from `start` to `start + len`.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
@property
|
||||||
|
def nelement(self) -> int:
|
||||||
|
"""
|
||||||
|
Gets the tensor's element count.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def powf(self, p: float) -> Tensor:
|
def powf(self, p: float) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the `pow` operation on the tensor with the given exponent.
|
Performs the `pow` operation on the tensor with the given exponent.
|
||||||
|
70
candle-pyo3/py_src/candle/testing/__init__.py
Normal file
70
candle-pyo3/py_src/candle/testing/__init__.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import candle
|
||||||
|
from candle import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_tensor_metadata(
|
||||||
|
actual: Tensor,
|
||||||
|
expected: Tensor,
|
||||||
|
check_device: bool = True,
|
||||||
|
check_dtype: bool = True,
|
||||||
|
check_layout: bool = True,
|
||||||
|
check_stride: bool = False,
|
||||||
|
):
|
||||||
|
if check_device:
|
||||||
|
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
|
||||||
|
|
||||||
|
if check_dtype:
|
||||||
|
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
|
||||||
|
|
||||||
|
if check_layout:
|
||||||
|
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
|
||||||
|
|
||||||
|
if check_stride:
|
||||||
|
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_equal(
|
||||||
|
actual: Tensor,
|
||||||
|
expected: Tensor,
|
||||||
|
check_device: bool = True,
|
||||||
|
check_dtype: bool = True,
|
||||||
|
check_layout: bool = True,
|
||||||
|
check_stride: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Asserts that two tensors are exact equals.
|
||||||
|
"""
|
||||||
|
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
|
||||||
|
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_almost_equal(
|
||||||
|
actual: Tensor,
|
||||||
|
expected: Tensor,
|
||||||
|
rtol=1e-05,
|
||||||
|
atol=1e-08,
|
||||||
|
check_device: bool = True,
|
||||||
|
check_dtype: bool = True,
|
||||||
|
check_layout: bool = True,
|
||||||
|
check_stride: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
|
||||||
|
|
||||||
|
Computes: |actual - expected| ≤ atol + rtol x |expected|
|
||||||
|
"""
|
||||||
|
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
|
||||||
|
|
||||||
|
# Secure against overflow of u32 and u8 tensors
|
||||||
|
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
|
||||||
|
actual = actual.to(candle.i64)
|
||||||
|
expected = expected.to(candle.i64)
|
||||||
|
|
||||||
|
diff = (actual - expected).abs()
|
||||||
|
|
||||||
|
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
|
||||||
|
|
||||||
|
assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"
|
@ -1,8 +1,11 @@
|
|||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::pyclass::CompareOp;
|
||||||
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
||||||
use pyo3::ToPyObject;
|
use pyo3::ToPyObject;
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::os::raw::c_long;
|
use std::os::raw::c_long;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -132,9 +135,10 @@ macro_rules! pydtype {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pydtype!(i64, |v| v);
|
||||||
pydtype!(u8, |v| v);
|
pydtype!(u8, |v| v);
|
||||||
pydtype!(u32, |v| v);
|
pydtype!(u32, |v| v);
|
||||||
pydtype!(i64, |v| v);
|
|
||||||
pydtype!(f16, f32::from);
|
pydtype!(f16, f32::from);
|
||||||
pydtype!(bf16, f32::from);
|
pydtype!(bf16, f32::from);
|
||||||
pydtype!(f32, |v| v);
|
pydtype!(f32, |v| v);
|
||||||
@ -317,6 +321,13 @@ impl PyTensor {
|
|||||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
PyTuple::new(py, self.0.dims()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
/// Gets the tensor's element count.
|
||||||
|
/// &RETURNS&: int
|
||||||
|
fn nelement(&self) -> usize {
|
||||||
|
self.0.elem_count()
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
/// Gets the tensor's strides.
|
/// Gets the tensor's strides.
|
||||||
/// &RETURNS&: Tuple[int]
|
/// &RETURNS&: Tuple[int]
|
||||||
@ -353,6 +364,12 @@ impl PyTensor {
|
|||||||
self.__repr__()
|
self.__repr__()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Performs the `abs` operation on the tensor.
|
||||||
|
/// &RETURNS&: Tensor
|
||||||
|
fn abs(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
/// Performs the `sin` operation on the tensor.
|
/// Performs the `sin` operation on the tensor.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn sin(&self) -> PyResult<Self> {
|
fn sin(&self) -> PyResult<Self> {
|
||||||
@ -670,6 +687,58 @@ impl PyTensor {
|
|||||||
};
|
};
|
||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
/// Rich-compare two tensors.
|
||||||
|
/// &RETURNS&: Tensor
|
||||||
|
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
|
||||||
|
let compare = |lhs: &Tensor, rhs: &Tensor| {
|
||||||
|
let t = match op {
|
||||||
|
CompareOp::Eq => lhs.eq(rhs),
|
||||||
|
CompareOp::Ne => lhs.ne(rhs),
|
||||||
|
CompareOp::Lt => lhs.lt(rhs),
|
||||||
|
CompareOp::Le => lhs.le(rhs),
|
||||||
|
CompareOp::Gt => lhs.gt(rhs),
|
||||||
|
CompareOp::Ge => lhs.ge(rhs),
|
||||||
|
};
|
||||||
|
Ok(PyTensor(t.map_err(wrap_err)?))
|
||||||
|
};
|
||||||
|
if let Ok(rhs) = rhs.extract::<PyTensor>() {
|
||||||
|
if self.0.shape() == rhs.0.shape() {
|
||||||
|
compare(&self.0, &rhs.0)
|
||||||
|
} else {
|
||||||
|
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
|
||||||
|
let broadcast_shape = self
|
||||||
|
.0
|
||||||
|
.shape()
|
||||||
|
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
|
||||||
|
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
|
||||||
|
|
||||||
|
compare(&broadcasted_lhs, &broadcasted_rhs)
|
||||||
|
}
|
||||||
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
let scalar_tensor = Tensor::new(rhs, self.0.device())
|
||||||
|
.map_err(wrap_err)?
|
||||||
|
.to_dtype(self.0.dtype())
|
||||||
|
.map_err(wrap_err)?
|
||||||
|
.broadcast_as(self.0.shape())
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
|
||||||
|
compare(&self.0, &scalar_tensor)
|
||||||
|
} else {
|
||||||
|
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __hash__(&self) -> u64 {
|
||||||
|
// we have overridden __richcmp__ => py03 wants us to also override __hash__
|
||||||
|
// we simply hash the address of the tensor
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
let pointer = &self.0 as *const Tensor;
|
||||||
|
let address = pointer as usize;
|
||||||
|
address.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||||
/// Reshapes the tensor to the given shape.
|
/// Reshapes the tensor to the given shape.
|
||||||
@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_class::<PyDType>()?;
|
m.add_class::<PyDType>()?;
|
||||||
m.add("u8", PyDType(DType::U8))?;
|
m.add("u8", PyDType(DType::U8))?;
|
||||||
m.add("u32", PyDType(DType::U32))?;
|
m.add("u32", PyDType(DType::U32))?;
|
||||||
m.add("i16", PyDType(DType::I64))?;
|
m.add("i64", PyDType(DType::I64))?;
|
||||||
m.add("bf16", PyDType(DType::BF16))?;
|
m.add("bf16", PyDType(DType::BF16))?;
|
||||||
m.add("f16", PyDType(DType::F16))?;
|
m.add("f16", PyDType(DType::F16))?;
|
||||||
m.add("f32", PyDType(DType::F32))?;
|
m.add("f32", PyDType(DType::F32))?;
|
||||||
|
33
candle-pyo3/tests/bindings/test_testing.py
Normal file
33
candle-pyo3/tests/bindings/test_testing.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import candle
|
||||||
|
from candle import Tensor
|
||||||
|
from candle.testing import assert_equal, assert_almost_equal
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
|
||||||
|
def test_assert_equal_asserts_correctly(dtype: candle.DType):
|
||||||
|
a = Tensor([1, 2, 3]).to(dtype)
|
||||||
|
b = Tensor([1, 2, 3]).to(dtype)
|
||||||
|
assert_equal(a, b)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
assert_equal(a, b + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
|
||||||
|
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
|
||||||
|
a = Tensor([1, 2, 3]).to(dtype)
|
||||||
|
b = Tensor([1, 2, 3]).to(dtype)
|
||||||
|
assert_almost_equal(a, b)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
assert_almost_equal(a, b + 1)
|
||||||
|
|
||||||
|
assert_almost_equal(a, b + 1, atol=20)
|
||||||
|
assert_almost_equal(a, b + 1, rtol=20)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
assert_almost_equal(a, b + 1, atol=0.9)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
assert_almost_equal(a, b + 1, rtol=0.1)
|
@ -1,6 +1,7 @@
|
|||||||
import candle
|
import candle
|
||||||
from candle import Tensor
|
from candle import Tensor
|
||||||
from candle.utils import cuda_is_available
|
from candle.utils import cuda_is_available
|
||||||
|
from candle.testing import assert_equal
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
|
|||||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_bool(t: Tensor, expected: bool):
|
||||||
|
assert t.shape == ()
|
||||||
|
assert str(t.dtype) == str(candle.u8)
|
||||||
|
assert bool(t.values()) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_supports_equality_opperations_with_scalars():
|
||||||
|
t = Tensor(42.0)
|
||||||
|
|
||||||
|
assert_bool(t == 42.0, True)
|
||||||
|
assert_bool(t == 43.0, False)
|
||||||
|
|
||||||
|
assert_bool(t != 42.0, False)
|
||||||
|
assert_bool(t != 43.0, True)
|
||||||
|
|
||||||
|
assert_bool(t > 41.0, True)
|
||||||
|
assert_bool(t > 42.0, False)
|
||||||
|
|
||||||
|
assert_bool(t >= 41.0, True)
|
||||||
|
assert_bool(t >= 42.0, True)
|
||||||
|
|
||||||
|
assert_bool(t < 43.0, True)
|
||||||
|
assert_bool(t < 42.0, False)
|
||||||
|
|
||||||
|
assert_bool(t <= 43.0, True)
|
||||||
|
assert_bool(t <= 42.0, True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_supports_equality_opperations_with_tensors():
|
||||||
|
t = Tensor(42.0)
|
||||||
|
same = Tensor(42.0)
|
||||||
|
other = Tensor(43.0)
|
||||||
|
|
||||||
|
assert_bool(t == same, True)
|
||||||
|
assert_bool(t == other, False)
|
||||||
|
|
||||||
|
assert_bool(t != same, False)
|
||||||
|
assert_bool(t != other, True)
|
||||||
|
|
||||||
|
assert_bool(t > same, False)
|
||||||
|
assert_bool(t > other, False)
|
||||||
|
|
||||||
|
assert_bool(t >= same, True)
|
||||||
|
assert_bool(t >= other, False)
|
||||||
|
|
||||||
|
assert_bool(t < same, False)
|
||||||
|
assert_bool(t < other, True)
|
||||||
|
|
||||||
|
assert_bool(t <= same, True)
|
||||||
|
assert_bool(t <= other, True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_equality_opperations_can_broadcast():
|
||||||
|
# Create a decoder attention mask as a test case
|
||||||
|
# e.g.
|
||||||
|
# [[1,0,0]
|
||||||
|
# [1,1,0]
|
||||||
|
# [1,1,1]]
|
||||||
|
mask_cond = candle.Tensor([0, 1, 2])
|
||||||
|
mask = mask_cond < (mask_cond + 1).reshape((3, 1))
|
||||||
|
assert mask.shape == (3, 3)
|
||||||
|
assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_can_be_hashed():
|
||||||
|
t = Tensor(42.0)
|
||||||
|
other = Tensor(42.0)
|
||||||
|
# Hash should represent a unique tensor
|
||||||
|
assert hash(t) != hash(other)
|
||||||
|
assert hash(t) == hash(t)
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_can_be_expanded_with_none():
|
def test_tensor_can_be_expanded_with_none():
|
||||||
t = candle.rand((12, 12))
|
t = candle.rand((12, 12))
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user