mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Always broadcast magic methods (#1101)
This commit is contained in:
@ -536,7 +536,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 + &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 + rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -553,7 +553,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 * &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 * rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -570,7 +570,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 - &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 - rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -583,7 +583,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 / &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 / rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user