mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling
* Rename to `PyShapeWithHole` + validate that only one hole exists
* Regenerate stubs

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
This commit is contained in:
@ -21,7 +21,7 @@ half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = "0.19"
|
||||
pyo3-build-config = "0.20"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||
|
||||
class bf16(DType):
|
||||
pass
|
||||
@ -26,21 +26,21 @@ class i64(DType):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
|
||||
def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor filled with ones.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
|
||||
def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor with random values.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
|
||||
def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor with random values from a normal distribution.
|
||||
"""
|
||||
@ -67,7 +67,7 @@ class u8(DType):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
|
||||
def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor filled with zeros.
|
||||
"""
|
||||
@ -174,7 +174,7 @@ class Tensor:
|
||||
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
def broadcast_as(self, shape: Sequence[int]) -> Tensor:
|
||||
def broadcast_as(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Broadcasts the tensor to the given shape.
|
||||
"""
|
||||
@ -184,7 +184,7 @@ class Tensor:
|
||||
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
def broadcast_left(self, shape: Sequence[int]) -> Tensor:
|
||||
def broadcast_left(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Broadcasts the tensor to the given shape, adding new dimensions on the left.
|
||||
"""
|
||||
@ -329,7 +329,7 @@ class Tensor:
|
||||
Get the `recip` of the tensor.
|
||||
"""
|
||||
pass
|
||||
def reshape(self, shape: Sequence[int]) -> Tensor:
|
||||
def reshape(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Reshapes the tensor to the given shape.
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
@staticmethod
|
||||
|
@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA)
|
||||
Scalar = Union[int, float]
|
||||
|
||||
Index = Union[int, slice, None, "Ellipsis"]
|
||||
|
||||
Shape = Union[int, Sequence[int]]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
@staticmethod
|
||||
|
@ -16,26 +16,13 @@ extern crate accelerate_src;
|
||||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
mod shape;
|
||||
use shape::{PyShape, PyShapeWithHole};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PyShape(Vec<usize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
|
||||
Ok(PyShape(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyShape> for ::candle::Shape {
|
||||
fn from(val: PyShape) -> Self {
|
||||
val.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
/// A `candle` tensor.
|
||||
@ -684,25 +671,37 @@ impl PyTensor {
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Reshapes the tensor to the given shape.
|
||||
/// &RETURNS&: Tensor
|
||||
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||
fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.reshape(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Broadcasts the tensor to the given shape.
|
||||
/// &RETURNS&: Tensor
|
||||
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
||||
fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.broadcast_as(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
|
||||
/// &RETURNS&: Tensor
|
||||
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
||||
fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.broadcast_left(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, dim:int)")]
|
||||
@ -915,21 +914,21 @@ impl PyTensor {
|
||||
}
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
if let Some(any) = kwargs.get_item("dtype") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("dtype") {
|
||||
handle_duplicates(
|
||||
&mut dtype,
|
||||
any.extract::<PyDType>(),
|
||||
"cannot specify multiple dtypes",
|
||||
)?;
|
||||
}
|
||||
if let Some(any) = kwargs.get_item("device") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("device") {
|
||||
handle_duplicates(
|
||||
&mut device,
|
||||
any.extract::<PyDevice>(),
|
||||
"cannot specify multiple devices",
|
||||
)?;
|
||||
}
|
||||
if let Some(any) = kwargs.get_item("other") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("other") {
|
||||
handle_duplicates(
|
||||
&mut other,
|
||||
any.extract::<PyTensor>(),
|
||||
@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor with random values.
|
||||
/// &RETURNS&: Tensor
|
||||
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor with random values from a normal distribution.
|
||||
/// &RETURNS&: Tensor
|
||||
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor filled with ones.
|
||||
/// &RETURNS&: Tensor
|
||||
fn ones(
|
||||
@ -1083,12 +1082,12 @@ fn ones(
|
||||
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||
};
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor filled with zeros.
|
||||
/// &RETURNS&: Tensor
|
||||
fn zeros(
|
||||
@ -1102,7 +1101,7 @@ fn zeros(
|
||||
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||
};
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
|
99
candle-pyo3/src/shape.rs
Normal file
99
candle-pyo3/src/shape.rs
Normal file
@ -0,0 +1,99 @@
|
||||
use ::candle::Tensor;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// An absolute shape, e.g. `(1, 2, 3)`.
///
/// Every dimension is stored as an unsigned size, so the wrapper can be handed
/// to candle as-is.
#[derive(Debug, Clone)]
pub struct PyShape(Vec<usize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
if ob.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"Shape cannot be None",
|
||||
));
|
||||
}
|
||||
|
||||
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
|
||||
if tuple.len() == 1 {
|
||||
let first_element = tuple.get_item(0)?;
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
|
||||
Ok(PyShape(dims))
|
||||
} else {
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
|
||||
Ok(PyShape(dims))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyShape> for ::candle::Shape {
|
||||
fn from(val: PyShape) -> Self {
|
||||
val.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// A shape that may contain a hole, e.g. `(1, -1, 3)`.
///
/// Dimensions are signed so that `-1` can mark the single dimension whose size
/// is to be inferred later.
#[derive(Debug, Clone)]
pub struct PyShapeWithHole(Vec<isize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
if ob.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"Shape cannot be None",
|
||||
));
|
||||
}
|
||||
|
||||
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
|
||||
let dims: Vec<isize> = if tuple.len() == 1 {
|
||||
let first_element = tuple.get_item(0)?;
|
||||
pyo3::FromPyObject::extract(first_element)?
|
||||
} else {
|
||||
pyo3::FromPyObject::extract(tuple)?
|
||||
};
|
||||
|
||||
// Ensure we have only positive numbers and at most one "hole" (-1)
|
||||
let negative_ones = dims.iter().filter(|&&x| x == -1).count();
|
||||
let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
|
||||
if negative_ones > 1 || any_invalid_dimensions {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"Invalid dimension in shape: {:?}",
|
||||
dims
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(PyShapeWithHole(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl PyShapeWithHole {
|
||||
/// Returns `true` if the shape is absolute e.g. (1, 2, 3)
|
||||
pub fn is_absolute(&self) -> bool {
|
||||
self.0.iter().all(|x| *x > 0)
|
||||
}
|
||||
|
||||
/// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
|
||||
pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
|
||||
if self.is_absolute() {
|
||||
return Ok(PyShape(
|
||||
self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut elements = t.elem_count();
|
||||
let mut new_dims: Vec<usize> = vec![];
|
||||
for dim in self.0.iter() {
|
||||
if *dim > 0 {
|
||||
new_dims.push(*dim as usize);
|
||||
elements /= *dim as usize;
|
||||
} else if *dim == -1 {
|
||||
new_dims.push(elements);
|
||||
} else {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"Invalid dimension in shape: {}",
|
||||
dim
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(PyShape(new_dims))
|
||||
}
|
||||
}
|
@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
||||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
"""
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
|
||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
||||
RETURN_TYPE_MARKER = "&RETURNS&: "
|
||||
ADDITIONAL_TYPEHINTS = {}
|
||||
|
31
candle-pyo3/tests/native/test_shape.py
Normal file
31
candle-pyo3/tests/native/test_shape.py
Normal file
@ -0,0 +1,31 @@
|
||||
from candle import Tensor
|
||||
from candle import rand
|
||||
import pytest
|
||||
|
||||
|
||||
def test_absolute_shapes_are_valid():
    """`rand` takes a shape as one tuple or as separate ints and rejects bad dimensions."""
    from_tuple = rand((10, 20))
    assert from_tuple.shape == (10, 20)

    from_args = rand(10, 20)
    assert from_args.shape == (10, 20)

    # Negative sizes are rejected in both call forms.
    with pytest.raises(OverflowError):
        rand((10, 20, -1))
    with pytest.raises(OverflowError):
        rand(-1, 20)
    # Non-integer dimensions are rejected as well.
    with pytest.raises(TypeError):
        rand("foo", True)
|
||||
|
||||
|
||||
def test_relative_shapes_are_valid():
    """`reshape` infers a single `-1` dimension and rejects malformed relative shapes."""
    as_row = rand(10, 20).reshape((1, -1))
    assert as_row.shape == (1, 200)

    as_column = rand(10, 20).reshape(-1, 1)
    assert as_column.shape == (200, 1)

    base = rand(10, 20)
    # Non-integer dimension.
    with pytest.raises(TypeError):
        base.reshape(1, "foo")
    # Dimensions below -1, in both call forms.
    with pytest.raises(ValueError):
        base.reshape(1, -2)
    with pytest.raises(ValueError):
        base.reshape((-2, 1))
    # Zero-sized dimension.
    with pytest.raises(ValueError):
        base.reshape((0, 1))
    # More than one hole.
    with pytest.raises(ValueError):
        base.reshape((1, -1, -1))
|
Reference in New Issue
Block a user