Merge pull request #53 from LaurentMazare/more-pyo3
Add more pyo3 wrapping
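The change is easiest to see from the Python side. A minimal sketch of the new surface, assuming the extension module is built and importable as candle (e.g. via maturin):

    import candle

    # The constructor now accepts ints, floats, and flat lists of either.
    t = candle.Tensor([3.0, 1.0, 4.0, 1.0])
    print(t.values())                      # round trip back to a Python list
    print((t + 1.0).values())              # scalar rhs via __add__
    print(t.reshape([2, 2]).t().values())  # new reshape/transpose wrappers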
@@ -4,7 +4,7 @@ use pyo3::types::{PyString, PyTuple};

 use half::{bf16, f16};

-use ::candle::{DType, Device::Cpu, Tensor};
+use ::candle::{DType, Device::Cpu, Tensor, WithDType};

 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
@@ -22,34 +22,103 @@ impl std::ops::Deref for PyTensor {
     }
 }

+trait PyDType: WithDType {
+    fn to_py(&self, py: Python<'_>) -> PyObject;
+}
+
+macro_rules! pydtype {
+    ($ty:ty, $conv:expr) => {
+        impl PyDType for $ty {
+            fn to_py(&self, py: Python<'_>) -> PyObject {
+                $conv(*self).to_object(py)
+            }
+        }
+    };
+}
+pydtype!(u8, |v| v);
+pydtype!(u32, |v| v);
+pydtype!(f16, f32::from);
+pydtype!(bf16, f32::from);
+pydtype!(f32, |v| v);
+pydtype!(f64, |v| v);
+
+// TODO: Something similar to this should probably be a part of candle core.
+trait MapDType {
+    type Output;
+    fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
+
+    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
+        match t.dtype() {
+            DType::U8 => self.f::<u8>(t),
+            DType::U32 => self.f::<u32>(t),
+            DType::BF16 => self.f::<bf16>(t),
+            DType::F16 => self.f::<f16>(t),
+            DType::F32 => self.f::<f32>(t),
+            DType::F64 => self.f::<f64>(t),
+        }
+    }
+}
+
 #[pymethods]
 impl PyTensor {
     #[new]
     // TODO: Handle arbitrary input dtype and shape.
-    fn new(f: f32) -> PyResult<Self> {
-        Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
+    fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+        let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<f32>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
+            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+        } else {
+            Err(PyTypeError::new_err("incorrect type for tensor"))?
+        };
+        Ok(Self(tensor))
     }

     /// Gets the tensor data as a Python value/array/array of array/...
     fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
-        // TODO: Handle arbitrary shapes.
-        let v = match self.0.dtype() {
-            // TODO: Use the map bits to avoid enumerating the types.
-            DType::U8 => self.to_scalar::<u8>().map_err(wrap_err)?.to_object(py),
-            DType::U32 => self.to_scalar::<u32>().map_err(wrap_err)?.to_object(py),
-            DType::F32 => self.to_scalar::<f32>().map_err(wrap_err)?.to_object(py),
-            DType::F64 => self.to_scalar::<f64>().map_err(wrap_err)?.to_object(py),
-            DType::BF16 => self
-                .to_scalar::<bf16>()
-                .map_err(wrap_err)?
-                .to_f32()
-                .to_object(py),
-            DType::F16 => self
-                .to_scalar::<f16>()
-                .map_err(wrap_err)?
-                .to_f32()
-                .to_object(py),
-        };
-        Ok(v)
+        struct M<'a>(Python<'a>);
+        impl<'a> MapDType for M<'a> {
+            type Output = PyObject;
+            fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
+                match t.rank() {
+                    0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
+                    1 => {
+                        let v = t.to_vec1::<T>().map_err(wrap_err)?;
+                        let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    2 => {
+                        let v = t.to_vec2::<T>().map_err(wrap_err)?;
+                        let v = v
+                            .iter()
+                            .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+                            .collect::<Vec<Vec<_>>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    3 => {
+                        let v = t.to_vec3::<T>().map_err(wrap_err)?;
+                        let v = v
+                            .iter()
+                            .map(|v| {
+                                v.iter()
+                                    .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+                                    .collect()
+                            })
+                            .collect::<Vec<Vec<Vec<_>>>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    n => Err(PyTypeError::new_err(format!(
+                        "TODO: conversion to PyObject is not handled for rank {n}"
+                    )))?,
+                }
+            }
+        }
+        M(py).map(self)
     }

     #[getter]
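The constructor tries u32, Vec&lt;u32&gt;, f32, then Vec&lt;f32&gt; in that order, and values reverses the mapping, turning ranks 0 through 3 back into a scalar or nested lists (higher ranks raise, and the half types come back as Python floats via f32). A hypothetical session under those rules:

    import candle

    print(candle.Tensor(5).values())    # Python int extracts as u32 first -> 5
    print(candle.Tensor(5.0).values())  # float falls through to f32 -> 5.0
    v = candle.Tensor([1.0, 2.0, 3.0, 4.0])
    print(v.values())                   # rank 1 -> [1.0, 2.0, 3.0, 4.0]
    print(v.reshape([2, 2]).values())   # rank 2 -> [[1.0, 2.0], [3.0, 4.0]]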
@@ -80,13 +149,23 @@ impl PyTensor {
         self.__repr__()
     }

     fn matmul(&self, rhs: &Self) -> PyResult<Self> {
         Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
     }

+    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
+        ))
+    }
+
     fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
         let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
             (&self.0 + &rhs.0).map_err(wrap_err)?
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 + rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for add"))?
+            Err(PyTypeError::new_err("unsupported rhs for add"))?
         };
         Ok(Self(tensor))
     }
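where_cond is called on the condition tensor and picks elements from on_true where the condition is nonzero, from on_false elsewhere. A sketch, assuming an integer mask of the same shape is accepted as the condition:

    import candle

    mask = candle.Tensor([1, 0, 1, 0])       # list of ints -> u32 mask
    a = candle.Tensor([1.0, 2.0, 3.0, 4.0])
    b = candle.Tensor([10.0, 20.0, 30.0, 40.0])
    print(mask.where_cond(a, b).values())    # [1.0, 20.0, 3.0, 40.0]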
@@ -101,7 +180,7 @@ impl PyTensor {
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 * rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for mul"))?
+            Err(PyTypeError::new_err("unsupported rhs for mul"))?
         };
         Ok(Self(tensor))
     }
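With __add__/__mul__/__sub__ each accepting either another tensor or a Python float, and __rmul__ delegating to __mul__, the usual operator sugar works from Python; a small sketch:

    import candle

    t = candle.Tensor([1.0, 2.0, 3.0])
    print((t + 1.0).values())  # float rhs -> [2.0, 3.0, 4.0]
    print((2.0 * t).values())  # reflected: Python calls __rmul__ -> [2.0, 4.0, 6.0]
    print((t - t).values())    # tensor rhs -> [0.0, 0.0, 0.0]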
@@ -109,17 +188,108 @@ impl PyTensor {
     fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
         self.__mul__(rhs)
     }

+    fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
+        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
+            (&self.0 - &rhs.0).map_err(wrap_err)?
+        } else if let Ok(rhs) = rhs.extract::<f64>() {
+            (&self.0 - rhs).map_err(wrap_err)?
+        } else {
+            Err(PyTypeError::new_err("unsupported rhs for sub"))?
+        };
+        Ok(Self(tensor))
+    }
+
+    // TODO: Add a PyShape type?
+    fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
+    }
+
+    fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+    }
+
+    fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+    }
+
+    fn squeeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
+    }
+
+    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
+    }
+
+    fn get(&self, index: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
+    }
+
+    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
+    }
+
+    fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
+    }
+
+    fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
+        // TODO: Support a single dim as input?
+        Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
+    }
+
+    fn sum_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
+    }
+
+    fn flatten_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
+    }
+
+    fn t(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
+    }
+
+    fn contiguous(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
+    }
+
+    fn is_contiguous(&self) -> bool {
+        self.0.is_contiguous()
+    }
+
+    fn is_fortran_contiguous(&self) -> bool {
+        self.0.is_fortran_contiguous()
+    }
+
+    fn detach(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+    }
+
+    fn copy(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
+    }
 }

 /// Concatenate the tensors across one axis.
 #[pyfunction]
 fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
     let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
     let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

 #[pyfunction]
-fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
-    let tensor = (&tensor.0 + f).map_err(wrap_err)?;
+fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

 #[pymodule]
 fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyTensor>()?;
-    m.add_function(wrap_pyfunction!(add, m)?)?;
     m.add_function(wrap_pyfunction!(cat, m)?)?;
+    m.add_function(wrap_pyfunction!(stack, m)?)?;
     Ok(())
 }
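cat joins tensors along an existing dimension, while the new stack inserts a fresh dimension at dim, so for two length-2 vectors the results differ in rank. A sketch:

    import candle

    a = candle.Tensor([1.0, 2.0])
    b = candle.Tensor([3.0, 4.0])
    print(candle.cat([a, b], 0).values())    # rank 1: [1.0, 2.0, 3.0, 4.0]
    print(candle.stack([a, b], 0).values())  # rank 2: [[1.0, 2.0], [3.0, 4.0]]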
@@ -4,3 +4,9 @@ t = candle.Tensor(42.0)
 print(t)
 print("shape", t.shape, t.rank)
 print(t + t)
+
+t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
+print(t)
+print(t+t)
+t = t.reshape([2, 4])
+print(t.matmul(t.t()))
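For reference, after the reshape t is [[3, 1, 4, 1], [5, 9, 2, 6]], so t.matmul(t.t()) is the 2x2 Gram matrix [[27, 38], [38, 146]] (e.g. the top-left entry is 3*3 + 1*1 + 4*4 + 1*1 = 27).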