From 5b8c6764b05cfe82340101372549aa2a97a0ffbb Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 2 Jul 2023 07:34:14 +0100 Subject: [PATCH] Add matmul/where_cond. --- candle-pyo3/src/lib.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 74847e5e..b1504ada 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -149,6 +149,16 @@ impl PyTensor { self.__repr__() } + fn matmul(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) + } + + fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult { + Ok(PyTensor( + self.0.where_cond(on_true, on_false).map_err(wrap_err)?, + )) + } + fn __add__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { (&self.0 + &rhs.0).map_err(wrap_err)? @@ -219,6 +229,15 @@ impl PyTensor { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } + fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult { + Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) + } + + fn sum(&self, dims: Vec) -> PyResult { + // TODO: Support a single dim as input? + Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?)) + } + fn sum_all(&self) -> PyResult { Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?)) }