diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 75a52b4d..fd2890f6 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -17,3 +17,4 @@ crate-type = ["cdylib"] [dependencies] candle = { path = "../candle-core", default-features=false } pyo3 = { version = "0.19.0", features = ["extension-module"] } +half = { version = "2.3.1", features = ["num-traits"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ab280a63..ddf7b554 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,8 +1,10 @@ -use pyo3::types::PyTuple; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::types::{PyString, PyTuple}; -use ::candle::{Device::Cpu, Tensor}; +use half::{bf16, f16}; + +use ::candle::{DType, Device::Cpu, Tensor}; pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) @@ -23,15 +25,48 @@ impl std::ops::Deref for PyTensor { #[pymethods] impl PyTensor { #[new] + // TODO: Handle arbitrary input dtype and shape. fn new(f: f32) -> PyResult { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } + fn values(&self, py: Python<'_>) -> PyResult { + // TODO: Handle arbitrary shapes. + let v = match self.0.dtype() { + // TODO: Use the map bits to avoid enumerating the types. + DType::U8 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::U32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::F32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::F64 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::BF16 => self + .to_scalar::() + .map_err(wrap_err)? + .to_f32() + .to_object(py), + DType::F16 => self + .to_scalar::() + .map_err(wrap_err)? + .to_f32() + .to_object(py), + }; + Ok(v) + } + #[getter] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.dims()).to_object(py) } + #[getter] + fn stride(&self, py: Python<'_>) -> PyObject { + PyTuple::new(py, self.0.stride()).to_object(py) + } + + #[getter] + fn dtype(&self, py: Python<'_>) -> PyObject { + PyString::new(py, self.0.dtype().as_str()).to_object(py) + } + #[getter] fn rank(&self) -> usize { self.0.rank() @@ -59,6 +94,21 @@ impl PyTensor { fn __radd__(&self, rhs: &PyAny) -> PyResult { self.__add__(rhs) } + + fn __mul__(&self, rhs: &PyAny) -> PyResult { + let tensor = if let Ok(rhs) = rhs.extract::() { + (&self.0 * &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::() { + (&self.0 * rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported for mul"))? + }; + Ok(Self(tensor)) + } + + fn __rmul__(&self, rhs: &PyAny) -> PyResult { + self.__mul__(rhs) + } } #[pyfunction]