diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 74847e5e..b1504ada 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -149,6 +149,16 @@ impl PyTensor { self.__repr__() } + fn matmul(&self, rhs: &Self) -> PyResult { + Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) + } + + fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult { + Ok(PyTensor( + self.0.where_cond(on_true, on_false).map_err(wrap_err)?, + )) + } + fn __add__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { (&self.0 + &rhs.0).map_err(wrap_err)? @@ -219,6 +229,15 @@ impl PyTensor { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } + fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult { + Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) + } + + fn sum(&self, dims: Vec) -> PyResult { + // TODO: Support a single dim as input? + Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?)) + } + fn sum_all(&self) -> PyResult { Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?)) }