// File: candle/candle-pyo3/src/lib.rs
// Listing metadata: 153 lines, 3.9 KiB, Rust

use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyString, PyTuple};
use half::{bf16, f16};
use ::candle::{DType, Device::Cpu, Tensor, WithDType};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
// Wrapper around a candle `Tensor`, exposed to Python as a class named "Tensor".
// `Clone` is derived; the arithmetic methods in this file obtain operands of this
// type by value through `extract::<Self>()`.
#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
impl std::ops::Deref for PyTensor {
type Target = Tensor;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// A candle element type (`WithDType`) that can also be turned into a Python object.
trait PyDType: WithDType {
    /// Converts this value into a Python object, using `py` as the interpreter token.
    fn to_py(&self, py: Python<'_>) -> PyObject;
}
/// Implements `PyDType` for `$ty`.
///
/// `$conv` is a callable applied to a copy of the value; its result is what gets
/// converted into the Python object.
macro_rules! pydtype {
    ($ty:ty, $conv:expr) => {
        impl PyDType for $ty {
            fn to_py(&self, py: Python<'_>) -> PyObject {
                let converted = $conv(*self);
                converted.to_object(py)
            }
        }
    };
}
// Integer and f32/f64 values are handed to Python unchanged.
pydtype!(u8, std::convert::identity);
pydtype!(u32, std::convert::identity);
// The two 16-bit float types go through `f32::from` before conversion.
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, std::convert::identity);
pydtype!(f64, std::convert::identity);
// TODO: Something similar to this should probably be a part of candle core.
/// Runs an element-type-generic operation on a tensor, choosing the element type from
/// the tensor's runtime `DType`.
///
/// Implementors supply `f`; callers use `map`.
trait MapDType {
    /// Value produced by the operation.
    type Output;

    /// The operation itself, generic over the element type `T` chosen by `map`.
    fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;

    /// Calls `f` with the Rust type that corresponds to `t.dtype()`.
    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
        let dtype = t.dtype();
        // No wildcard arm: every `DType` variant is listed explicitly.
        match dtype {
            // Integer element types.
            DType::U8 => self.f::<u8>(t),
            DType::U32 => self.f::<u32>(t),
            // 16-bit float element types.
            DType::F16 => self.f::<f16>(t),
            DType::BF16 => self.f::<bf16>(t),
            // 32- and 64-bit float element types.
            DType::F32 => self.f::<f32>(t),
            DType::F64 => self.f::<f64>(t),
        }
    }
}
#[pymethods]
impl PyTensor {
    // Constructor: wraps a single `f32` value in a tensor created on the CPU device.
    #[new]
    // TODO: Handle arbitrary input dtype and shape.
    fn new(f: f32) -> PyResult<Self> {
        let tensor = Tensor::new(f, &Cpu).map_err(wrap_err)?;
        Ok(Self(tensor))
    }

    // Reads the tensor as a scalar and returns it as a Python object.
    // The element type is picked from the tensor's dtype via `MapDType::map`;
    // a failure in `to_scalar` is reported through `wrap_err`.
    fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
        // Carries the interpreter token into the dtype-generic callback.
        struct ScalarToPy<'py>(Python<'py>);

        impl<'py> MapDType for ScalarToPy<'py> {
            type Output = PyObject;

            fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
                let value = t.to_scalar::<T>().map_err(wrap_err)?;
                Ok(value.to_py(self.0))
            }
        }

        // TODO: Handle arbitrary shapes.
        ScalarToPy(py).map(self)
    }

    // Tensor dimensions as a Python tuple.
    #[getter]
    fn shape(&self, py: Python<'_>) -> PyObject {
        let dims = self.0.dims();
        PyTuple::new(py, dims).to_object(py)
    }

    // Tensor strides as a Python tuple.
    #[getter]
    fn stride(&self, py: Python<'_>) -> PyObject {
        let strides = self.0.stride();
        PyTuple::new(py, strides).to_object(py)
    }

    // Name of the tensor's dtype as a Python string.
    #[getter]
    fn dtype(&self, py: Python<'_>) -> PyObject {
        let dtype = self.0.dtype();
        PyString::new(py, dtype.as_str()).to_object(py)
    }

    // Number of dimensions, as reported by the wrapped tensor.
    #[getter]
    fn rank(&self) -> usize {
        let Self(tensor) = self;
        tensor.rank()
    }

    // Uses the wrapped tensor's `Display` output.
    fn __repr__(&self) -> String {
        self.0.to_string()
    }

    // Same text as `__repr__`.
    fn __str__(&self) -> String {
        self.__repr__()
    }

    // `self + rhs`, where `rhs` is tried first as another tensor and then as a float.
    // Any other operand type raises `TypeError`; candle errors go through `wrap_err`.
    fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
        if let Ok(other) = rhs.extract::<Self>() {
            return (&self.0 + &other.0).map(Self).map_err(wrap_err);
        }
        if let Ok(scalar) = rhs.extract::<f64>() {
            return (&self.0 + scalar).map(Self).map_err(wrap_err);
        }
        Err(PyTypeError::new_err("unsupported for add"))
    }

    // Reflected addition simply forwards to `__add__`.
    fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__add__(rhs)
    }

    // `self * rhs`, where `rhs` is tried first as another tensor and then as a float.
    // Any other operand type raises `TypeError`; candle errors go through `wrap_err`.
    fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
        if let Ok(other) = rhs.extract::<Self>() {
            return (&self.0 * &other.0).map(Self).map_err(wrap_err);
        }
        if let Ok(scalar) = rhs.extract::<f64>() {
            return (&self.0 * scalar).map(Self).map_err(wrap_err);
        }
        Err(PyTypeError::new_err("unsupported for mul"))
    }

    // Reflected multiplication simply forwards to `__mul__`.
    fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__mul__(rhs)
    }
}
#[pyfunction]
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
let tensor = (&tensor.0 + f).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
// Module initialiser: registers the `PyTensor` class and the `add` function on `m`.
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyTensor>()?;
    let add_fn = wrap_pyfunction!(add, m)?;
    m.add_function(add_fn)
}