From 42d1a52d01f5f10e3a04257cb4612225f08e1321 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 20:55:15 +0100 Subject: [PATCH] Add two methods. --- candle-pyo3/src/lib.rs | 14 ++++++++++++++ candle-pyo3/test.py | 1 + 2 files changed, 15 insertions(+) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index be9e427d..e1ce7f97 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -25,9 +25,23 @@ impl PyTensor { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } + #[getter] + fn shape(&self) -> Vec { + self.0.dims().to_vec() + } + + #[getter] + fn rank(&self) -> usize { + self.0.rank() + } + fn __repr__(&self) -> String { format!("{}", self.0) } + + fn __str__(&self) -> String { + self.__repr__() + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index b8a2818e..21242b44 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -2,4 +2,5 @@ import candle t = candle.Tensor(42.0) print(t) +print("shape", t.shape, t.rank) print(candle.add(t, 3.14))