PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)

* add `equal` to tensor

* add `__richcmp__` support for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`

Author: Lukas Kreussel
Date: 2023-10-30 16:17:28 +01:00
Committed by: GitHub
Parent: 969960847a
Commit: c05c0a8213
7 changed files with 325 additions and 3 deletions
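For context, a rough sketch of what the headline change looks like from the Python side (illustrative only: tensor construction via `candle.Tensor(...)` and the `values()` accessor are assumed to work as in the existing bindings; the `equal` helper from the first bullet is not part of the hunks shown below and presumably gives an all-elements-equal check):

```python
import candle

a = candle.Tensor([1.0, 2.0, 3.0])
b = candle.Tensor([1.0, 2.0, -3.0])

# `__richcmp__` routes Python's comparison operators to candle's element-wise
# cmp ops, so `==` now returns a mask tensor instead of falling back to identity.
mask = a == b
print(mask.values())

# Scalars are accepted on the right-hand side as well (see the scalar branch below).
print((a > 0.0).values())
```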

@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
@@ -132,9 +135,10 @@ macro_rules! pydtype {
}
};
}
pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
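These `pydtype!` registrations back the scalar conversions for the dtypes exposed on the module. A small sketch of how the dtype objects are used from Python (assuming `to_dtype` accepts a dtype object and `dtype` is a getter, as elsewhere in these bindings):

```python
import candle

t = candle.Tensor([1.0, 2.5, -3.0])
u = t.to_dtype(candle.i64)   # cast; `candle.i64` is the now correctly named alias, see the module setup below
print(u.dtype)               # the dtype object registered for I64
```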
@@ -317,6 +321,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}
#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelement(&self) -> usize {
self.0.elem_count()
}
#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
@@ -353,6 +364,12 @@ impl PyTensor {
self.__repr__()
}
/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
}
/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
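From Python, the new getter and method above amount to the following (a sketch; construction and `values()` assumed as above):

```python
import candle

t = candle.Tensor([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])
print(t.nelement)         # getter -> 6, renamed from `nelements` per the commit message
print(t.abs().values())   # element-wise absolute value
```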
@@ -670,6 +687,58 @@ impl PyTensor {
};
Ok(Self(tensor))
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
CompareOp::Ne => lhs.ne(rhs),
CompareOp::Lt => lhs.lt(rhs),
CompareOp::Le => lhs.le(rhs),
CompareOp::Gt => lhs.gt(rhs),
CompareOp::Ge => lhs.ge(rhs),
};
Ok(PyTensor(t.map_err(wrap_err)?))
};
if let Ok(rhs) = rhs.extract::<PyTensor>() {
if self.0.shape() == rhs.0.shape() {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = self
.0
.shape()
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
.map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
compare(&broadcasted_lhs, &broadcasted_rhs)
}
} else if let Ok(rhs) = rhs.extract::<f64>() {
let scalar_tensor = Tensor::new(rhs, self.0.device())
.map_err(wrap_err)?
.to_dtype(self.0.dtype())
.map_err(wrap_err)?
.broadcast_as(self.0.shape())
.map_err(wrap_err)?;
compare(&self.0, &scalar_tensor)
} else {
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
}
}
fn __hash__(&self) -> u64 {
// we have overridden __richcmp__ => pyo3 wants us to also override __hash__
// we simply hash the address of the tensor
let mut hasher = DefaultHasher::new();
let pointer = &self.0 as *const Tensor;
let address = pointer as usize;
address.hash(&mut hasher);
hasher.finish()
}
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
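The branches above give three behaviours that are easy to check from Python: tensors with different shapes are broadcast to a common shape before the element-wise compare, a Python scalar is materialised as a constant tensor with the lhs dtype and shape, and any other right-hand side raises `TypeError`. A hedged sketch (constructor and `values()` assumed as above):

```python
import candle

a = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])
row = candle.Tensor([1.0, 4.0])

mask = a == row      # shapes (2, 2) and (2,) are broadcast to (2, 2) first
big = a >= 2.5       # scalar rhs -> constant tensor of a's dtype and shape

try:
    a == "not a tensor"   # unsupported rhs -> TypeError from the fallback branch
except TypeError as err:
    print(err)
```

Note that because `__richcmp__` is overridden, `__hash__` is reimplemented to hash the address of the underlying Rust tensor, so two tensors with identical values will generally not hash to the same value.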
@@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("i16", PyDType(DType::I64))?;
m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;