mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling * Rename to `PyShapeWithHole` + validate that only one hole exists * Regenerate stubs --------- Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
This commit is contained in:
@ -16,26 +16,13 @@ extern crate accelerate_src;
|
||||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
mod shape;
|
||||
use shape::{PyShape, PyShapeWithHole};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PyShape(Vec<usize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
|
||||
Ok(PyShape(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyShape> for ::candle::Shape {
|
||||
fn from(val: PyShape) -> Self {
|
||||
val.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
/// A `candle` tensor.
|
||||
@ -684,25 +671,37 @@ impl PyTensor {
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Reshapes the tensor to the given shape.
|
||||
/// &RETURNS&: Tensor
|
||||
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||
fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.reshape(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Broadcasts the tensor to the given shape.
|
||||
/// &RETURNS&: Tensor
|
||||
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
||||
fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.broadcast_as(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
|
||||
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
|
||||
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
|
||||
/// &RETURNS&: Tensor
|
||||
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
||||
fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
|
||||
Ok(PyTensor(
|
||||
self.0
|
||||
.broadcast_left(shape.to_absolute(&self.0)?)
|
||||
.map_err(wrap_err)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, dim:int)")]
|
||||
@ -915,21 +914,21 @@ impl PyTensor {
|
||||
}
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
if let Some(any) = kwargs.get_item("dtype") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("dtype") {
|
||||
handle_duplicates(
|
||||
&mut dtype,
|
||||
any.extract::<PyDType>(),
|
||||
"cannot specify multiple dtypes",
|
||||
)?;
|
||||
}
|
||||
if let Some(any) = kwargs.get_item("device") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("device") {
|
||||
handle_duplicates(
|
||||
&mut device,
|
||||
any.extract::<PyDevice>(),
|
||||
"cannot specify multiple devices",
|
||||
)?;
|
||||
}
|
||||
if let Some(any) = kwargs.get_item("other") {
|
||||
if let Ok(Some(any)) = kwargs.get_item("other") {
|
||||
handle_duplicates(
|
||||
&mut other,
|
||||
any.extract::<PyTensor>(),
|
||||
@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor with random values.
|
||||
/// &RETURNS&: Tensor
|
||||
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor with random values from a normal distribution.
|
||||
/// &RETURNS&: Tensor
|
||||
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor filled with ones.
|
||||
/// &RETURNS&: Tensor
|
||||
fn ones(
|
||||
@ -1083,12 +1082,12 @@ fn ones(
|
||||
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||
};
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
|
||||
/// Creates a new tensor filled with zeros.
|
||||
/// &RETURNS&: Tensor
|
||||
fn zeros(
|
||||
@ -1102,7 +1101,7 @@ fn zeros(
|
||||
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||
};
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user