use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyString, PyTuple};

use half::{bf16, f16};

use ::candle::{DType, Device::Cpu, Tensor, WithDType};

/// Convert a candle error into a Python `ValueError` carrying the debug message.
pub fn wrap_err(err: ::candle::Error) -> PyErr {
    PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

/// Newtype wrapper exposing a candle tensor to Python as `candle.Tensor`.
#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);

impl std::ops::Deref for PyTensor {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Scalar types that can be converted into a Python object.
trait PyDType: WithDType {
    fn to_py(&self, py: Python<'_>) -> PyObject;
}

macro_rules! pydtype {
    ($ty:ty, $conv:expr) => {
        impl PyDType for $ty {
            fn to_py(&self, py: Python<'_>) -> PyObject {
                $conv(*self).to_object(py)
            }
        }
    };
}
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
// Half-precision values are widened to f32 since Python has no native f16/bf16.
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);

// TODO: Something similar to this should probably be a part of candle core.
/// Dispatch a generic operation over the runtime dtype of a tensor.
trait MapDType {
    type Output;
    fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;

    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
        match t.dtype() {
            DType::U8 => self.f::<u8>(t),
            DType::U32 => self.f::<u32>(t),
            DType::BF16 => self.f::<bf16>(t),
            DType::F16 => self.f::<f16>(t),
            DType::F32 => self.f::<f32>(t),
            DType::F64 => self.f::<f64>(t),
        }
    }
}

#[pymethods]
impl PyTensor {
    #[new]
    // TODO: Handle arbitrary input dtype and shape.
    fn new(f: f32) -> PyResult<Self> {
        Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
    }

    /// Extract the value of a rank-0 tensor as a Python scalar.
    fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
        struct M<'a>(Python<'a>);
        impl<'a> MapDType for M<'a> {
            type Output = PyObject;
            fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
                Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
            }
        }
        // TODO: Handle arbitrary shapes.
        M(py).map(self)
    }

    #[getter]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.dims()).to_object(py)
    }

    #[getter]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.stride()).to_object(py)
    }

    #[getter]
    fn dtype(&self, py: Python<'_>) -> PyObject {
        PyString::new(py, self.0.dtype().as_str()).to_object(py)
    }

    #[getter]
    fn rank(&self) -> usize {
        self.0.rank()
    }

    fn __repr__(&self) -> String {
        format!("{}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Addition with either another tensor or a float scalar.
    fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 + &rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 + rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported for add"))?
        };
        Ok(Self(tensor))
    }

    fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__add__(rhs)
    }

    /// Multiplication with either another tensor or a float scalar.
    fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 * &rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 * rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported for mul"))?
        };
        Ok(Self(tensor))
    }

    fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__mul__(rhs)
    }
}

#[pyfunction]
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
    let tensor = (&tensor.0 + f).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyTensor>()?;
    m.add_function(wrap_pyfunction!(add, m)?)?;
    Ok(())
}
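
// A minimal usage sketch from the Python side (assumes the extension module has
// been built, e.g. with maturin, and is importable as `candle`; the printed
// values are illustrative):
//
//     import candle
//
//     t = candle.Tensor(3.0)
//     print(t.shape, t.dtype, t.rank)  # () f32 0
//     u = t + 1.5                      # __add__ with a float scalar
//     print(u.scalar())                # 4.5
//     v = candle.add(t, 2.0)           # module-level helper
//     print((t * v).scalar())          # 15.0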