diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index be9e427d..e1ce7f97 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -25,9 +25,23 @@ impl PyTensor { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } + #[getter] + fn shape(&self) -> Vec { + self.0.dims().to_vec() + } + + #[getter] + fn rank(&self) -> usize { + self.0.rank() + } + fn __repr__(&self) -> String { format!("{}", self.0) } + + fn __str__(&self) -> String { + self.__repr__() + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index b8a2818e..21242b44 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -2,4 +2,5 @@ import candle t = candle.Tensor(42.0) print(t) +print("shape", t.shape, t.rank) print(candle.add(t, 3.14))