mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add a trait to avoid repeating the dtype matching.
This commit is contained in:
@ -4,7 +4,7 @@ use pyo3::types::{PyString, PyTuple};
|
||||
|
||||
use half::{bf16, f16};
|
||||
|
||||
use ::candle::{DType, Device::Cpu, Tensor};
|
||||
use ::candle::{DType, Device::Cpu, Tensor, WithDType};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
@ -22,6 +22,43 @@ impl std::ops::Deref for PyTensor {
|
||||
}
|
||||
}
|
||||
|
||||
trait PyDType: WithDType {
|
||||
fn to_py(&self, py: Python<'_>) -> PyObject;
|
||||
}
|
||||
|
||||
macro_rules! pydtype {
|
||||
($ty:ty, $conv:expr) => {
|
||||
impl PyDType for $ty {
|
||||
fn to_py(&self, py: Python<'_>) -> PyObject {
|
||||
$conv(*self).to_object(py)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
pydtype!(u8, |v| v);
|
||||
pydtype!(u32, |v| v);
|
||||
pydtype!(f16, f32::from);
|
||||
pydtype!(bf16, f32::from);
|
||||
pydtype!(f32, |v| v);
|
||||
pydtype!(f64, |v| v);
|
||||
|
||||
// TODO: Something similar to this should probably be a part of candle core.
|
||||
trait MapDType {
|
||||
type Output;
|
||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
|
||||
|
||||
fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||
match t.dtype() {
|
||||
DType::U8 => self.f::<u8>(t),
|
||||
DType::U32 => self.f::<u32>(t),
|
||||
DType::BF16 => self.f::<bf16>(t),
|
||||
DType::F16 => self.f::<f16>(t),
|
||||
DType::F32 => self.f::<f32>(t),
|
||||
DType::F64 => self.f::<f64>(t),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
@ -30,26 +67,16 @@ impl PyTensor {
|
||||
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
struct M<'a>(Python<'a>);
|
||||
impl<'a> MapDType for M<'a> {
|
||||
type Output = PyObject;
|
||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||
Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
|
||||
}
|
||||
}
|
||||
// TODO: Handle arbitrary shapes.
|
||||
let v = match self.0.dtype() {
|
||||
// TODO: Use the map bits to avoid enumerating the types.
|
||||
DType::U8 => self.to_scalar::<u8>().map_err(wrap_err)?.to_object(py),
|
||||
DType::U32 => self.to_scalar::<u32>().map_err(wrap_err)?.to_object(py),
|
||||
DType::F32 => self.to_scalar::<f32>().map_err(wrap_err)?.to_object(py),
|
||||
DType::F64 => self.to_scalar::<f64>().map_err(wrap_err)?.to_object(py),
|
||||
DType::BF16 => self
|
||||
.to_scalar::<bf16>()
|
||||
.map_err(wrap_err)?
|
||||
.to_f32()
|
||||
.to_object(py),
|
||||
DType::F16 => self
|
||||
.to_scalar::<f16>()
|
||||
.map_err(wrap_err)?
|
||||
.to_f32()
|
||||
.to_object(py),
|
||||
};
|
||||
Ok(v)
|
||||
M(py).map(self)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
|
Reference in New Issue
Block a user