Pyo3 dtype (#327)

* Better handling of dtypes in pyo3.

* More pyo3 dtype.
This commit is contained in:
Laurent Mazare
2023-08-06 10:17:43 +01:00
committed by GitHub
parent 88bd3b604a
commit 93cfe5642f
4 changed files with 61 additions and 22 deletions

View File

@ -40,21 +40,30 @@ impl std::ops::Deref for PyTensor {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
struct PyDType(DType);
impl<'source> FromPyObject<'source> for PyDType {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
use std::str::FromStr;
let dtype: &str = ob.extract()?;
let dtype = DType::from_str(dtype)
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
Ok(Self(dtype))
#[pymethods]
impl PyDType {
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
fn __str__(&self) -> String {
self.__repr__()
}
}
impl ToPyObject for PyDType {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.0.as_str().to_object(py)
impl PyDType {
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
use std::str::FromStr;
if let Ok(dtype) = ob.extract::<&str>(py) {
let dtype = DType::from_str(dtype)
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
Ok(Self(dtype))
} else {
ob.extract(py)
}
}
}
@ -223,8 +232,8 @@ impl PyTensor {
}
#[getter]
fn dtype(&self, py: Python<'_>) -> PyObject {
PyDType(self.0.dtype()).to_object(py)
fn dtype(&self) -> PyDType {
PyDType(self.0.dtype())
}
#[getter]
@ -367,7 +376,8 @@ impl PyTensor {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
@ -416,12 +426,15 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn ones(
_py: Python<'_>,
py: Python<'_>,
shape: PyShape,
dtype: Option<PyDType>,
dtype: Option<PyObject>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
let dtype = match dtype {
None => DType::F32,
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@ -430,12 +443,15 @@ fn ones(
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn zeros(
_py: Python<'_>,
py: Python<'_>,
shape: PyShape,
dtype: Option<PyDType>,
dtype: Option<PyObject>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
let dtype = match dtype {
None => DType::F32,
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@ -444,6 +460,13 @@ fn zeros(
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;
m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;