From 9a9858bbe00adc0be84a63df56d5d26078bf81a1 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:30:00 +0100
Subject: [PATCH] Expose a couple more ops.

---
 candle-pyo3/src/lib.rs | 87 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 82 insertions(+), 5 deletions(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index c85b41f0..74847e5e 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -155,7 +155,7 @@ impl PyTensor {
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 + rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for add"))?
+            Err(PyTypeError::new_err("unsupported rhs for add"))?
         };
         Ok(Self(tensor))
     }
@@ -170,7 +170,7 @@
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 * rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for mul"))?
+            Err(PyTypeError::new_err("unsupported rhs for mul"))?
         };
         Ok(Self(tensor))
     }
@@ -179,21 +179,98 @@ impl PyTensor {
         self.__mul__(rhs)
     }
 
+    fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
+        let tensor = if let Ok(rhs) = rhs.extract::<PyTensor>() {
+            (&self.0 - &rhs.0).map_err(wrap_err)?
+        } else if let Ok(rhs) = rhs.extract::<f64>() {
+            (&self.0 - rhs).map_err(wrap_err)?
+        } else {
+            Err(PyTypeError::new_err("unsupported rhs for sub"))?
+        };
+        Ok(Self(tensor))
+    }
+
     // TODO: Add a PyShape type?
     fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
         Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
     }
+
+    fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+    }
+
+    fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+    }
+
+    fn squeeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
+    }
+
+    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
+    }
+
+    fn get(&self, index: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
+    }
+
+    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
+    }
+
+    fn sum_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
+    }
+
+    fn flatten_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
+    }
+
+    fn t(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
+    }
+
+    fn contiguous(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
+    }
+
+    fn is_contiguous(&self) -> bool {
+        self.0.is_contiguous()
+    }
+
+    fn is_fortran_contiguous(&self) -> bool {
+        self.0.is_fortran_contiguous()
+    }
+
+    fn detach(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+    }
+
+    fn copy(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
+    }
+}
+
+/// Concatenate the tensors across one axis.
+#[pyfunction]
+fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+    let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
+    Ok(PyTensor(tensor))
 }
 
 #[pyfunction]
-fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
-    let tensor = (&tensor.0 + f).map_err(wrap_err)?;
+fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
 #[pymodule]
 fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyTensor>()?;
-    m.add_function(wrap_pyfunction!(add, m)?)?;
+    m.add_function(wrap_pyfunction!(cat, m)?)?;
+    m.add_function(wrap_pyfunction!(stack, m)?)?;
     Ok(())
 }
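
As a rough, informal sketch of what these bindings expose on the Python side (not part of the patch itself): assuming the extension module is importable as `candle` and that `PyTensor` is registered under a Python-visible class name such as `Tensor` with a constructor accepting nested lists (an assumption, not shown in this diff), the new ops could be exercised roughly like this:

    import candle

    # Hypothetical usage sketch: the Tensor constructor and class name are
    # assumptions; cat, stack and the method names come from the patch above.
    a = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])
    b = candle.Tensor([[5.0, 6.0], [7.0, 8.0]])

    c = candle.cat([a, b], 0)    # concatenate along dim 0 -> shape (4, 2)
    s = candle.stack([a, b], 0)  # stack on a new leading dim -> shape (2, 2, 2)

    d = a - b                    # __sub__ with a tensor rhs
    e = a - 1.0                  # __sub__ with a scalar rhs
    t = a.t()                    # swap the last two dims
    u = a.sum_all()              # sum over all elements
    v = a.unsqueeze(0).squeeze(0).contiguous()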