use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyString, PyTuple};

use half::{bf16, f16};

use ::candle::{DType, Device::Cpu, Tensor, WithDType};

/// Converts a candle error into a Python `ValueError`.
pub fn wrap_err(err: ::candle::Error) -> PyErr {
    PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

/// Wrapper around a candle tensor, exposed to Python as `candle.Tensor`.
#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);

impl std::ops::Deref for PyTensor {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Candle dtypes that can be converted into a native Python object.
trait PyDType: WithDType {
    fn to_py(&self, py: Python<'_>) -> PyObject;
}

macro_rules! pydtype {
    ($ty:ty, $conv:expr) => {
        impl PyDType for $ty {
            fn to_py(&self, py: Python<'_>) -> PyObject {
                $conv(*self).to_object(py)
            }
        }
    };
}
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
// f16/bf16 have no native Python counterpart, so they are widened to f32.
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);

// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
    type Output;
    fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;

    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
        match t.dtype() {
            DType::U8 => self.f::<u8>(t),
            DType::U32 => self.f::<u32>(t),
            DType::BF16 => self.f::<bf16>(t),
            DType::F16 => self.f::<f16>(t),
            DType::F32 => self.f::<f32>(t),
            DType::F64 => self.f::<f64>(t),
        }
    }
}

#[pymethods]
impl PyTensor {
    #[new]
    // TODO: Handle arbitrary input dtype and shape.
    fn new(f: f32) -> PyResult<Self> {
        Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
    }

    fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
        struct M<'a>(Python<'a>);
        impl<'a> MapDType for M<'a> {
            type Output = PyObject;
            fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
                Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
            }
        }
        // TODO: Handle arbitrary shapes.
        M(py).map(self)
    }

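    // A hedged sketch (not part of the original bindings) of how the TODO
    // above could be approached for rank-1 tensors: reuse the `MapDType`
    // dispatch together with `Tensor::to_vec1` to build a Python list. The
    // method name `values` is hypothetical.
    //
    // fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
    //     struct M<'a>(Python<'a>);
    //     impl<'a> MapDType for M<'a> {
    //         type Output = PyObject;
    //         fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
    //             let v = t.to_vec1::<T>().map_err(wrap_err)?;
    //             let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
    //             Ok(v.to_object(self.0))
    //         }
    //     }
    //     M(py).map(self)
    // }
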
    #[getter]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.dims()).to_object(py)
    }

    #[getter]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.stride()).to_object(py)
    }

    #[getter]
    fn dtype(&self, py: Python<'_>) -> PyObject {
        PyString::new(py, self.0.dtype().as_str()).to_object(py)
    }

    #[getter]
    fn rank(&self) -> usize {
        self.0.rank()
    }

    fn __repr__(&self) -> String {
        format!("{}", self.0)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 + &rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 + rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported for add"))?
        };
        Ok(Self(tensor))
    }

    fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__add__(rhs)
    }

    fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 * &rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 * rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported for mul"))?
        };
        Ok(Self(tensor))
    }

    fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
        self.__mul__(rhs)
    }
}

/// Adds a scalar to a tensor, returning a new tensor.
#[pyfunction]
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
    let tensor = (&tensor.0 + f).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
}

#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyTensor>()?;
    m.add_function(wrap_pyfunction!(add, m)?)?;
    Ok(())
}

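// Illustrative Python usage of this module (a sketch; assumes the extension
// has been built and installed in the current environment, e.g. with
// `maturin develop`):
//
//     import candle
//
//     t = candle.Tensor(3.0)
//     print(t.shape, t.dtype, t.rank)     # (), 'f32', 0
//     print((t + 1.0).scalar())           # 4.0
//     print(candle.add(t, 2.0).scalar())  # 5.0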