mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
@ -6,7 +6,7 @@ use pyo3::prelude::*;
|
||||
pub struct PyShape(Vec<usize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
|
||||
if ob.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"Shape cannot be None",
|
||||
@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
|
||||
if tuple.len() == 1 {
|
||||
let first_element = tuple.get_item(0)?;
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?;
|
||||
Ok(PyShape(dims))
|
||||
} else {
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?;
|
||||
Ok(PyShape(dims))
|
||||
}
|
||||
}
|
||||
@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape {
|
||||
pub struct PyShapeWithHole(Vec<isize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
|
||||
if ob.is_none() {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
|
||||
"Shape cannot be None",
|
||||
@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
|
||||
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
|
||||
let dims: Vec<isize> = if tuple.len() == 1 {
|
||||
let first_element = tuple.get_item(0)?;
|
||||
pyo3::FromPyObject::extract(first_element)?
|
||||
pyo3::FromPyObject::extract_bound(&first_element)?
|
||||
} else {
|
||||
pyo3::FromPyObject::extract(tuple)?
|
||||
pyo3::FromPyObject::extract_bound(tuple)?
|
||||
};
|
||||
|
||||
// Ensure we have only positive numbers and at most one "hole" (-1)
|
||||
|
Reference in New Issue
Block a user