mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
PyO3: Add equal
and __richcmp__
to candle.Tensor
(#1099)
* add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs`
This commit is contained in:
@ -1,8 +1,11 @@
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
||||
use pyo3::ToPyObject;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::os::raw::c_long;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -132,9 +135,10 @@ macro_rules! pydtype {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pydtype!(i64, |v| v);
|
||||
pydtype!(u8, |v| v);
|
||||
pydtype!(u32, |v| v);
|
||||
pydtype!(i64, |v| v);
|
||||
pydtype!(f16, f32::from);
|
||||
pydtype!(bf16, f32::from);
|
||||
pydtype!(f32, |v| v);
|
||||
@ -317,6 +321,13 @@ impl PyTensor {
|
||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Gets the tensor's element count.
|
||||
/// &RETURNS&: int
|
||||
fn nelement(&self) -> usize {
|
||||
self.0.elem_count()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Gets the tensor's strides.
|
||||
/// &RETURNS&: Tuple[int]
|
||||
@ -353,6 +364,12 @@ impl PyTensor {
|
||||
self.__repr__()
|
||||
}
|
||||
|
||||
/// Performs the `abs` operation on the tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn abs(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
/// Performs the `sin` operation on the tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn sin(&self) -> PyResult<Self> {
|
||||
@ -670,6 +687,58 @@ impl PyTensor {
|
||||
};
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
/// Rich-compare two tensors.
|
||||
/// &RETURNS&: Tensor
|
||||
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
|
||||
let compare = |lhs: &Tensor, rhs: &Tensor| {
|
||||
let t = match op {
|
||||
CompareOp::Eq => lhs.eq(rhs),
|
||||
CompareOp::Ne => lhs.ne(rhs),
|
||||
CompareOp::Lt => lhs.lt(rhs),
|
||||
CompareOp::Le => lhs.le(rhs),
|
||||
CompareOp::Gt => lhs.gt(rhs),
|
||||
CompareOp::Ge => lhs.ge(rhs),
|
||||
};
|
||||
Ok(PyTensor(t.map_err(wrap_err)?))
|
||||
};
|
||||
if let Ok(rhs) = rhs.extract::<PyTensor>() {
|
||||
if self.0.shape() == rhs.0.shape() {
|
||||
compare(&self.0, &rhs.0)
|
||||
} else {
|
||||
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
|
||||
let broadcast_shape = self
|
||||
.0
|
||||
.shape()
|
||||
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
|
||||
.map_err(wrap_err)?;
|
||||
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
|
||||
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
|
||||
|
||||
compare(&broadcasted_lhs, &broadcasted_rhs)
|
||||
}
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
let scalar_tensor = Tensor::new(rhs, self.0.device())
|
||||
.map_err(wrap_err)?
|
||||
.to_dtype(self.0.dtype())
|
||||
.map_err(wrap_err)?
|
||||
.broadcast_as(self.0.shape())
|
||||
.map_err(wrap_err)?;
|
||||
|
||||
compare(&self.0, &scalar_tensor)
|
||||
} else {
|
||||
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
|
||||
}
|
||||
}
|
||||
|
||||
fn __hash__(&self) -> u64 {
|
||||
// we have overridden __richcmp__ => py03 wants us to also override __hash__
|
||||
// we simply hash the address of the tensor
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let pointer = &self.0 as *const Tensor;
|
||||
let address = pointer as usize;
|
||||
address.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Reshapes the tensor to the given shape.
|
||||
@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyDType>()?;
|
||||
m.add("u8", PyDType(DType::U8))?;
|
||||
m.add("u32", PyDType(DType::U32))?;
|
||||
m.add("i16", PyDType(DType::I64))?;
|
||||
m.add("i64", PyDType(DType::I64))?;
|
||||
m.add("bf16", PyDType(DType::BF16))?;
|
||||
m.add("f16", PyDType(DType::F16))?;
|
||||
m.add("f32", PyDType(DType::F32))?;
|
||||
|
Reference in New Issue
Block a user