diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ddf7b554..1d3e4efd 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -4,7 +4,7 @@ use pyo3::types::{PyString, PyTuple}; use half::{bf16, f16}; -use ::candle::{DType, Device::Cpu, Tensor}; +use ::candle::{DType, Device::Cpu, Tensor, WithDType}; pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) @@ -22,6 +22,43 @@ impl std::ops::Deref for PyTensor { } } +trait PyDType: WithDType { + fn to_py(&self, py: Python<'_>) -> PyObject; +} + +macro_rules! pydtype { + ($ty:ty, $conv:expr) => { + impl PyDType for $ty { + fn to_py(&self, py: Python<'_>) -> PyObject { + $conv(*self).to_object(py) + } + } + }; +} +pydtype!(u8, |v| v); +pydtype!(u32, |v| v); +pydtype!(f16, f32::from); +pydtype!(bf16, f32::from); +pydtype!(f32, |v| v); +pydtype!(f64, |v| v); + +// TODO: Something similar to this should probably be a part of candle core. +trait MapDType { + type Output; + fn f(&self, t: &Tensor) -> PyResult; + + fn map(&self, t: &Tensor) -> PyResult { + match t.dtype() { + DType::U8 => self.f::(t), + DType::U32 => self.f::(t), + DType::BF16 => self.f::(t), + DType::F16 => self.f::(t), + DType::F32 => self.f::(t), + DType::F64 => self.f::(t), + } + } +} + #[pymethods] impl PyTensor { #[new] @@ -30,26 +67,16 @@ impl PyTensor { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } - fn values(&self, py: Python<'_>) -> PyResult { + fn scalar(&self, py: Python<'_>) -> PyResult { + struct M<'a>(Python<'a>); + impl<'a> MapDType for M<'a> { + type Output = PyObject; + fn f(&self, t: &Tensor) -> PyResult { + Ok(t.to_scalar::().map_err(wrap_err)?.to_py(self.0)) + } + } // TODO: Handle arbitrary shapes. - let v = match self.0.dtype() { - // TODO: Use the map bits to avoid enumerating the types. - DType::U8 => self.to_scalar::().map_err(wrap_err)?.to_object(py), - DType::U32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), - DType::F32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), - DType::F64 => self.to_scalar::().map_err(wrap_err)?.to_object(py), - DType::BF16 => self - .to_scalar::() - .map_err(wrap_err)? - .to_f32() - .to_object(py), - DType::F16 => self - .to_scalar::() - .map_err(wrap_err)? - .to_f32() - .to_object(py), - }; - Ok(v) + M(py).map(self) } #[getter]