PyO3: Better shape handling (#1143)

* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
This commit is contained in:
Lukas Kreussel
2023-10-29 16:41:44 +01:00
committed by GitHub
parent 154c674a79
commit 174b208052
10 changed files with 181 additions and 50 deletions

View File

@ -21,7 +21,7 @@ half = { workspace = true, optional = true }
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true } pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true } rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }

View File

@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" } candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true } half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies] [build-dependencies]
pyo3-build-config = "0.19" pyo3-build-config = "0.20"
[features] [features]
default = [] default = []

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
class bf16(DType): class bf16(DType):
pass pass
@ -26,21 +26,21 @@ class i64(DType):
pass pass
@staticmethod @staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
""" """
Creates a new tensor filled with ones. Creates a new tensor filled with ones.
""" """
pass pass
@staticmethod @staticmethod
def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
""" """
Creates a new tensor with random values. Creates a new tensor with random values.
""" """
pass pass
@staticmethod @staticmethod
def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
""" """
Creates a new tensor with random values from a normal distribution. Creates a new tensor with random values from a normal distribution.
""" """
@ -67,7 +67,7 @@ class u8(DType):
pass pass
@staticmethod @staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
""" """
Creates a new tensor filled with zeros. Creates a new tensor filled with zeros.
""" """
@ -174,7 +174,7 @@ class Tensor:
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def broadcast_as(self, shape: Sequence[int]) -> Tensor: def broadcast_as(self, *shape: Shape) -> Tensor:
""" """
Broadcasts the tensor to the given shape. Broadcasts the tensor to the given shape.
""" """
@ -184,7 +184,7 @@ class Tensor:
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def broadcast_left(self, shape: Sequence[int]) -> Tensor: def broadcast_left(self, *shape: Shape) -> Tensor:
""" """
Broadcasts the tensor to the given shape, adding new dimensions on the left. Broadcasts the tensor to the given shape, adding new dimensions on the left.
""" """
@ -329,7 +329,7 @@ class Tensor:
Get the `recip` of the tensor. Get the `recip` of the tensor.
""" """
pass pass
def reshape(self, shape: Sequence[int]) -> Tensor: def reshape(self, *shape: Shape) -> Tensor:
""" """
Reshapes the tensor to the given shape. Reshapes the tensor to the given shape.
""" """

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor from candle import Tensor, DType, QTensor
@staticmethod @staticmethod

View File

@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA)
Scalar = Union[int, float] Scalar = Union[int, float]
Index = Union[int, slice, None, "Ellipsis"] Index = Union[int, slice, None, "Ellipsis"]
Shape = Union[int, Sequence[int]]

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT # Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor from candle import Tensor, DType, QTensor
@staticmethod @staticmethod

View File

@ -16,26 +16,13 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod shape;
use shape::{PyShape, PyShapeWithHole};
pub fn wrap_err(err: ::candle::Error) -> PyErr { pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}")) PyErr::new::<PyValueError, _>(format!("{err:?}"))
} }
#[derive(Clone, Debug)]
struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
Ok(PyShape(dims))
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[pyclass(name = "Tensor")] #[pyclass(name = "Tensor")]
/// A `candle` tensor. /// A `candle` tensor.
@ -684,25 +671,37 @@ impl PyTensor {
Ok(Self(tensor)) Ok(Self(tensor))
} }
#[pyo3(text_signature = "(self, shape:Sequence[int])")] #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape. /// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> { fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) Ok(PyTensor(
self.0
.reshape(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
} }
#[pyo3(text_signature = "(self, shape:Sequence[int])")] #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape. /// Broadcasts the tensor to the given shape.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) Ok(PyTensor(
self.0
.broadcast_as(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
} }
#[pyo3(text_signature = "(self, shape:Sequence[int])")] #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape, adding new dimensions on the left. /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) Ok(PyTensor(
self.0
.broadcast_left(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
} }
#[pyo3(text_signature = "(self, dim:int)")] #[pyo3(text_signature = "(self, dim:int)")]
@ -915,21 +914,21 @@ impl PyTensor {
} }
if let Some(kwargs) = kwargs { if let Some(kwargs) = kwargs {
if let Some(any) = kwargs.get_item("dtype") { if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates( handle_duplicates(
&mut dtype, &mut dtype,
any.extract::<PyDType>(), any.extract::<PyDType>(),
"cannot specify multiple dtypes", "cannot specify multiple dtypes",
)?; )?;
} }
if let Some(any) = kwargs.get_item("device") { if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates( handle_duplicates(
&mut device, &mut device,
any.extract::<PyDevice>(), any.extract::<PyDevice>(),
"cannot specify multiple devices", "cannot specify multiple devices",
)?; )?;
} }
if let Some(any) = kwargs.get_item("other") { if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates( handle_duplicates(
&mut other, &mut other,
any.extract::<PyTensor>(), any.extract::<PyTensor>(),
@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] #[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values. /// Creates a new tensor with random values.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor)) Ok(PyTensor(tensor))
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] #[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution. /// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor)) Ok(PyTensor(tensor))
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] #[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones. /// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn ones( fn ones(
@ -1083,12 +1082,12 @@ fn ones(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
}; };
let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor)) Ok(PyTensor(tensor))
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] #[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros. /// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn zeros( fn zeros(
@ -1102,7 +1101,7 @@ fn zeros(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
}; };
let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor)) Ok(PyTensor(tensor))
} }

99
candle-pyo3/src/shape.rs Normal file
View File

@ -0,0 +1,99 @@
use ::candle::Tensor;
use pyo3::prelude::*;
#[derive(Clone, Debug)]
/// Represents an absolute shape e.g. (1, 2, 3)
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
Ok(PyShape(dims))
}
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
/// Represents a shape with a hole in it e.g. (1, -1, 3)
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
pyo3::FromPyObject::extract(first_element)?
} else {
pyo3::FromPyObject::extract(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)
let negative_ones = dims.iter().filter(|&&x| x == -1).count();
let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
if negative_ones > 1 || any_invalid_dimensions {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {:?}",
dims
)));
}
Ok(PyShapeWithHole(dims))
}
}
impl PyShapeWithHole {
/// Returns `true` if the shape is absolute e.g. (1, 2, 3)
pub fn is_absolute(&self) -> bool {
self.0.iter().all(|x| *x > 0)
}
/// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
if self.is_absolute() {
return Ok(PyShape(
self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
));
}
let mut elements = t.elem_count();
let mut new_dims: Vec<usize> = vec![];
for dim in self.0.iter() {
if *dim > 0 {
new_dims.push(*dim as usize);
elements /= *dim as usize;
} else if *dim == -1 {
new_dims.push(elements);
} else {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {}",
dim
)));
}
}
Ok(PyShape(new_dims))
}
}

View File

@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike from os import PathLike
""" """
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n" CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: " RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {} ADDITIONAL_TYPEHINTS = {}

View File

@ -0,0 +1,31 @@
from candle import Tensor
from candle import rand
import pytest
def test_absolute_shapes_are_valid():
a = rand((10, 20))
assert a.shape == (10, 20)
b = rand(10, 20)
assert b.shape == (10, 20)
pytest.raises(OverflowError, lambda: rand((10, 20, -1)))
pytest.raises(OverflowError, lambda: rand(-1, 20))
pytest.raises(TypeError, lambda: rand("foo", True))
def test_relative_shapes_are_valid():
a = rand(10, 20)
a = a.reshape((1, -1))
assert a.shape == (1, 200)
b = rand(10, 20)
b = b.reshape(-1, 1)
assert b.shape == (200, 1)
c = rand(10, 20)
pytest.raises(TypeError, lambda: c.reshape(1, "foo"))
pytest.raises(ValueError, lambda: c.reshape(1, -2))
pytest.raises(ValueError, lambda: c.reshape((-2, 1)))
pytest.raises(ValueError, lambda: c.reshape((0, 1)))
pytest.raises(ValueError, lambda: c.reshape((1, -1, -1)))