mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add some binary operators.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use ::candle::{Device::Cpu, Tensor};
|
||||
@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
struct PyTensor(Tensor);
|
||||
|
||||
@ -42,6 +43,21 @@ impl PyTensor {
|
||||
fn __str__(&self) -> String {
|
||||
self.__repr__()
|
||||
}
|
||||
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 + &rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 + rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
Err(PyTypeError::new_err("unsupported for add"))?
|
||||
};
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
self.__add__(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
|
@ -3,4 +3,4 @@ import candle
|
||||
t = candle.Tensor(42.0)
|
||||
print(t)
|
||||
print("shape", t.shape, t.rank)
|
||||
print(candle.add(t, 3.14))
|
||||
print(t + t)
|
||||
|
Reference in New Issue
Block a user