mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add dtype support.
This commit is contained in:
@ -10,6 +10,24 @@ pub enum DType {
|
|||||||
F64,
|
F64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct DTypeParseError;
|
||||||
|
|
||||||
|
impl std::str::FromStr for DType {
|
||||||
|
type Err = DTypeParseError;
|
||||||
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
|
match s {
|
||||||
|
"u8" => Ok(Self::U8),
|
||||||
|
"u32" => Ok(Self::U32),
|
||||||
|
"bf16" => Ok(Self::BF16),
|
||||||
|
"f16" => Ok(Self::F16),
|
||||||
|
"f32" => Ok(Self::F32),
|
||||||
|
"f64" => Ok(Self::F64),
|
||||||
|
_ => Err(DTypeParseError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl DType {
|
impl DType {
|
||||||
pub fn as_str(&self) -> &'static str {
|
pub fn as_str(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::{PyString, PyTuple};
|
use pyo3::types::PyTuple;
|
||||||
|
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
@ -22,13 +22,32 @@ impl std::ops::Deref for PyTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait PyDType: WithDType {
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
struct PyDType(DType);
|
||||||
|
|
||||||
|
impl<'source> FromPyObject<'source> for PyDType {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
use std::str::FromStr;
|
||||||
|
let dtype: &str = ob.extract()?;
|
||||||
|
let dtype = DType::from_str(dtype)
|
||||||
|
.map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?;
|
||||||
|
Ok(Self(dtype))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToPyObject for PyDType {
|
||||||
|
fn to_object(&self, py: Python<'_>) -> PyObject {
|
||||||
|
self.0.as_str().to_object(py)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait PyWithDType: WithDType {
|
||||||
fn to_py(&self, py: Python<'_>) -> PyObject;
|
fn to_py(&self, py: Python<'_>) -> PyObject;
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! pydtype {
|
macro_rules! pydtype {
|
||||||
($ty:ty, $conv:expr) => {
|
($ty:ty, $conv:expr) => {
|
||||||
impl PyDType for $ty {
|
impl PyWithDType for $ty {
|
||||||
fn to_py(&self, py: Python<'_>) -> PyObject {
|
fn to_py(&self, py: Python<'_>) -> PyObject {
|
||||||
$conv(*self).to_object(py)
|
$conv(*self).to_object(py)
|
||||||
}
|
}
|
||||||
@ -45,7 +64,7 @@ pydtype!(f64, |v| v);
|
|||||||
// TODO: Something similar to this should probably be a part of candle core.
|
// TODO: Something similar to this should probably be a part of candle core.
|
||||||
trait MapDType {
|
trait MapDType {
|
||||||
type Output;
|
type Output;
|
||||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
|
fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
|
||||||
|
|
||||||
fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
|
fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||||
match t.dtype() {
|
match t.dtype() {
|
||||||
@ -83,7 +102,7 @@ impl PyTensor {
|
|||||||
struct M<'a>(Python<'a>);
|
struct M<'a>(Python<'a>);
|
||||||
impl<'a> MapDType for M<'a> {
|
impl<'a> MapDType for M<'a> {
|
||||||
type Output = PyObject;
|
type Output = PyObject;
|
||||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||||
match t.rank() {
|
match t.rank() {
|
||||||
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
|
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
|
||||||
1 => {
|
1 => {
|
||||||
@ -133,7 +152,7 @@ impl PyTensor {
|
|||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn dtype(&self, py: Python<'_>) -> PyObject {
|
fn dtype(&self, py: Python<'_>) -> PyObject {
|
||||||
PyString::new(py, self.0.dtype().as_str()).to_object(py)
|
PyDType(self.0.dtype()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
@ -269,6 +288,10 @@ impl PyTensor {
|
|||||||
fn copy(&self) -> PyResult<Self> {
|
fn copy(&self) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
|
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenate the tensors across one axis.
|
/// Concatenate the tensors across one axis.
|
||||||
|
@ -10,3 +10,4 @@ print(t)
|
|||||||
print(t+t)
|
print(t+t)
|
||||||
t = t.reshape([2, 4])
|
t = t.reshape([2, 4])
|
||||||
print(t.matmul(t.t()))
|
print(t.matmul(t.t()))
|
||||||
|
print(t.to_dtype("u8"))
|
||||||
|
Reference in New Issue
Block a user