From fbbde5b02cc4a711aed609f887d1705d67b2fd20 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 21:27:35 +0100 Subject: [PATCH] Add some binary operators. --- candle-pyo3/src/lib.rs | 18 +++++++++++++++++- candle-pyo3/test.py | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index e1ce7f97..35de86c8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,4 @@ -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use ::candle::{Device::Cpu, Tensor}; @@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } +#[derive(Clone)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -42,6 +43,21 @@ impl PyTensor { fn __str__(&self) -> String { self.__repr__() } + + fn __add__(&self, rhs: &PyAny) -> PyResult { + let tensor = if let Ok(rhs) = rhs.extract::() { + (&self.0 + &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::() { + (&self.0 + rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported for add"))? + }; + Ok(Self(tensor)) + } + + fn __radd__(&self, rhs: &PyAny) -> PyResult { + self.__add__(rhs) + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 21242b44..0db3fab9 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -3,4 +3,4 @@ import candle t = candle.Tensor(42.0) print(t) print("shape", t.shape, t.rank) -print(candle.add(t, 3.14)) +print(t + t)