Always broadcast magic methods (#1101)

This commit is contained in:
Lukas Kreussel
2023-10-17 11:57:12 +02:00
committed by GitHub
parent 2fe24ac5b1
commit b355ab4e2e
2 changed files with 77 additions and 4 deletions

View File

@ -536,7 +536,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 + &rhs.0).map_err(wrap_err)?
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 + rhs).map_err(wrap_err)?
} else {
@ -553,7 +553,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 * &rhs.0).map_err(wrap_err)?
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 * rhs).map_err(wrap_err)?
} else {
@ -570,7 +570,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 - &rhs.0).map_err(wrap_err)?
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 - rhs).map_err(wrap_err)?
} else {
@ -583,7 +583,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 / &rhs.0).map_err(wrap_err)?
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 / rhs).map_err(wrap_err)?
} else {