Add some binary operators.

This commit is contained in:
laurent
2023-07-01 21:27:35 +01:00
parent 42d1a52d01
commit fbbde5b02c
2 changed files with 18 additions and 2 deletions

View File

@ -1,4 +1,4 @@
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use ::candle::{Device::Cpu, Tensor};
@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
@ -42,6 +43,21 @@ impl PyTensor {
fn __str__(&self) -> String {
self.__repr__()
}
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 + &rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 + rhs).map_err(wrap_err)?
} else {
Err(PyTypeError::new_err("unsupported for add"))?
};
Ok(Self(tensor))
}
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
self.__add__(rhs)
}
}
#[pyfunction]

View File

@ -3,4 +3,4 @@ import candle
t = candle.Tensor(42.0)
print(t)
print("shape", t.shape, t.rank)
print(candle.add(t, 3.14))
print(t + t)