diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f23907dd..6108b613 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2126,3 +2126,69 @@ bin_trait!(Add, add, |_| 1., |v| v); bin_trait!(Sub, sub, |_| 1., |v: f64| -v); bin_trait!(Mul, mul, |v| v, |_| 0.); bin_trait!(Div, div, |v| 1. / v, |_| 0.); + +impl std::ops::Add for f64 { + type Output = Result; + + fn add(self, rhs: Tensor) -> Self::Output { + rhs + self + } +} + +impl std::ops::Add<&Tensor> for f64 { + type Output = Result; + + fn add(self, rhs: &Tensor) -> Self::Output { + rhs + self + } +} + +impl std::ops::Mul for f64 { + type Output = Result; + + fn mul(self, rhs: Tensor) -> Self::Output { + rhs * self + } +} + +impl std::ops::Mul<&Tensor> for f64 { + type Output = Result; + + fn mul(self, rhs: &Tensor) -> Self::Output { + rhs * self + } +} + +impl std::ops::Sub for f64 { + type Output = Result; + + fn sub(self, rhs: Tensor) -> Self::Output { + rhs.affine(-1., self) + } +} + +impl std::ops::Sub<&Tensor> for f64 { + type Output = Result; + + fn sub(self, rhs: &Tensor) -> Self::Output { + rhs.affine(-1., self) + } +} + +impl std::ops::Div for f64 { + type Output = Result; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Tensor) -> Self::Output { + rhs.recip()? * self + } +} + +impl std::ops::Div<&Tensor> for f64 { + type Output = Result; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: &Tensor) -> Self::Output { + rhs.recip()? * self + } +}