diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 35de86c8..ab280a63 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,3 +1,4 @@ +use pyo3::types::PyTuple; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; @@ -27,8 +28,8 @@ impl PyTensor { } #[getter] - fn shape(&self) -> Vec { - self.0.dims().to_vec() + fn shape(&self, py: Python<'_>) -> PyObject { + PyTuple::new(py, self.0.dims()).to_object(py) } #[getter]