mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More pyo3.
This commit is contained in:
@ -17,3 +17,4 @@ crate-type = ["cdylib"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", default-features=false }
|
candle = { path = "../candle-core", default-features=false }
|
||||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||||
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
use pyo3::types::PyTuple;
|
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::types::{PyString, PyTuple};
|
||||||
|
|
||||||
use ::candle::{Device::Cpu, Tensor};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
|
use ::candle::{DType, Device::Cpu, Tensor};
|
||||||
|
|
||||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||||
@ -23,15 +25,48 @@ impl std::ops::Deref for PyTensor {
|
|||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyTensor {
|
impl PyTensor {
|
||||||
#[new]
|
#[new]
|
||||||
|
// TODO: Handle arbitrary input dtype and shape.
|
||||||
fn new(f: f32) -> PyResult<Self> {
|
fn new(f: f32) -> PyResult<Self> {
|
||||||
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
|
// TODO: Handle arbitrary shapes.
|
||||||
|
let v = match self.0.dtype() {
|
||||||
|
// TODO: Use the map bits to avoid enumerating the types.
|
||||||
|
DType::U8 => self.to_scalar::<u8>().map_err(wrap_err)?.to_object(py),
|
||||||
|
DType::U32 => self.to_scalar::<u32>().map_err(wrap_err)?.to_object(py),
|
||||||
|
DType::F32 => self.to_scalar::<f32>().map_err(wrap_err)?.to_object(py),
|
||||||
|
DType::F64 => self.to_scalar::<f64>().map_err(wrap_err)?.to_object(py),
|
||||||
|
DType::BF16 => self
|
||||||
|
.to_scalar::<bf16>()
|
||||||
|
.map_err(wrap_err)?
|
||||||
|
.to_f32()
|
||||||
|
.to_object(py),
|
||||||
|
DType::F16 => self
|
||||||
|
.to_scalar::<f16>()
|
||||||
|
.map_err(wrap_err)?
|
||||||
|
.to_f32()
|
||||||
|
.to_object(py),
|
||||||
|
};
|
||||||
|
Ok(v)
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
PyTuple::new(py, self.0.dims()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn stride(&self, py: Python<'_>) -> PyObject {
|
||||||
|
PyTuple::new(py, self.0.stride()).to_object(py)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn dtype(&self, py: Python<'_>) -> PyObject {
|
||||||
|
PyString::new(py, self.0.dtype().as_str()).to_object(py)
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn rank(&self) -> usize {
|
fn rank(&self) -> usize {
|
||||||
self.0.rank()
|
self.0.rank()
|
||||||
@ -59,6 +94,21 @@ impl PyTensor {
|
|||||||
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
self.__add__(rhs)
|
self.__add__(rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
|
(&self.0 * &rhs.0).map_err(wrap_err)?
|
||||||
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
(&self.0 * rhs).map_err(wrap_err)?
|
||||||
|
} else {
|
||||||
|
Err(PyTypeError::new_err("unsupported for mul"))?
|
||||||
|
};
|
||||||
|
Ok(Self(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
|
self.__mul__(rhs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
Reference in New Issue
Block a user