PyO3: Better shape handling (#1143)

* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
This commit is contained in:
Lukas Kreussel
2023-10-29 16:41:44 +01:00
committed by GitHub
parent 154c674a79
commit 174b208052
10 changed files with 181 additions and 50 deletions

View File

@ -16,26 +16,13 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod shape;
use shape::{PyShape, PyShapeWithHole};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)]
struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
Ok(PyShape(dims))
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
@ -684,25 +671,37 @@ impl PyTensor {
Ok(Self(tensor))
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.reshape(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape.
/// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_as(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
/// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_left(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, dim:int)")]
@ -915,21 +914,21 @@ impl PyTensor {
}
if let Some(kwargs) = kwargs {
if let Some(any) = kwargs.get_item("dtype") {
if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates(
&mut dtype,
any.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
}
if let Some(any) = kwargs.get_item("device") {
if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates(
&mut device,
any.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
}
if let Some(any) = kwargs.get_item("other") {
if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates(
&mut other,
any.extract::<PyTensor>(),
@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
@ -1083,12 +1082,12 @@ fn ones(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
@ -1102,7 +1101,7 @@ fn zeros(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}

99
candle-pyo3/src/shape.rs Normal file
View File

@ -0,0 +1,99 @@
use ::candle::Tensor;
use pyo3::prelude::*;
#[derive(Clone, Debug)]
/// Represents an absolute shape e.g. (1, 2, 3)
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
Ok(PyShape(dims))
}
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
/// Represents a shape with a hole in it e.g. (1, -1, 3)
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
pyo3::FromPyObject::extract(first_element)?
} else {
pyo3::FromPyObject::extract(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)
let negative_ones = dims.iter().filter(|&&x| x == -1).count();
let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
if negative_ones > 1 || any_invalid_dimensions {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {:?}",
dims
)));
}
Ok(PyShapeWithHole(dims))
}
}
impl PyShapeWithHole {
/// Returns `true` if the shape is absolute e.g. (1, 2, 3)
pub fn is_absolute(&self) -> bool {
self.0.iter().all(|x| *x > 0)
}
/// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
if self.is_absolute() {
return Ok(PyShape(
self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
));
}
let mut elements = t.elem_count();
let mut new_dims: Vec<usize> = vec![];
for dim in self.0.iter() {
if *dim > 0 {
new_dims.push(*dim as usize);
elements /= *dim as usize;
} else if *dim == -1 {
new_dims.push(elements);
} else {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {}",
dim
)));
}
}
Ok(PyShape(new_dims))
}
}