From 78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 2 Jul 2023 20:12:26 +0100 Subject: [PATCH] Add dtype support. --- candle-core/src/dtype.rs | 18 ++++++++++++++++++ candle-pyo3/src/lib.rs | 35 +++++++++++++++++++++++++++++------ candle-pyo3/test.py | 1 + 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index e6785491..8ce70f64 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -10,6 +10,24 @@ pub enum DType { F64, } +#[derive(Debug, PartialEq, Eq)] +pub struct DTypeParseError; + +impl std::str::FromStr for DType { + type Err = DTypeParseError; + fn from_str(s: &str) -> std::result::Result { + match s { + "u8" => Ok(Self::U8), + "u32" => Ok(Self::U32), + "bf16" => Ok(Self::BF16), + "f16" => Ok(Self::F16), + "f32" => Ok(Self::F32), + "f64" => Ok(Self::F64), + _ => Err(DTypeParseError), + } + } +} + impl DType { pub fn as_str(&self) -> &'static str { match self { diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b1504ada..7da91b3f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,6 +1,6 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyString, PyTuple}; +use pyo3::types::PyTuple; use half::{bf16, f16}; @@ -22,13 +22,32 @@ impl std::ops::Deref for PyTensor { } } -trait PyDType: WithDType { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct PyDType(DType); + +impl<'source> FromPyObject<'source> for PyDType { + fn extract(ob: &'source PyAny) -> PyResult { + use std::str::FromStr; + let dtype: &str = ob.extract()?; + let dtype = DType::from_str(dtype) + .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?; + Ok(Self(dtype)) + } +} + +impl ToPyObject for PyDType { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.0.as_str().to_object(py) + } +} + +trait PyWithDType: WithDType { fn to_py(&self, py: Python<'_>) -> PyObject; } macro_rules! pydtype { ($ty:ty, $conv:expr) => { - impl PyDType for $ty { + impl PyWithDType for $ty { fn to_py(&self, py: Python<'_>) -> PyObject { $conv(*self).to_object(py) } @@ -45,7 +64,7 @@ pydtype!(f64, |v| v); // TODO: Something similar to this should probably be a part of candle core. trait MapDType { type Output; - fn f(&self, t: &Tensor) -> PyResult; + fn f(&self, t: &Tensor) -> PyResult; fn map(&self, t: &Tensor) -> PyResult { match t.dtype() { @@ -83,7 +102,7 @@ impl PyTensor { struct M<'a>(Python<'a>); impl<'a> MapDType for M<'a> { type Output = PyObject; - fn f(&self, t: &Tensor) -> PyResult { + fn f(&self, t: &Tensor) -> PyResult { match t.rank() { 0 => Ok(t.to_scalar::().map_err(wrap_err)?.to_py(self.0)), 1 => { @@ -133,7 +152,7 @@ impl PyTensor { #[getter] fn dtype(&self, py: Python<'_>) -> PyObject { - PyString::new(py, self.0.dtype().as_str()).to_object(py) + PyDType(self.0.dtype()).to_object(py) } #[getter] @@ -269,6 +288,10 @@ impl PyTensor { fn copy(&self) -> PyResult { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } + + fn to_dtype(&self, dtype: PyDType) -> PyResult { + Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) + } } /// Concatenate the tensors across one axis. diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index d63f752b..1d792de5 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -10,3 +10,4 @@ print(t) print(t+t) t = t.reshape([2, 4]) print(t.matmul(t.t())) +print(t.to_dtype("u8"))