mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Implement a couple more binary ops. (#693)
This commit is contained in:
@ -2126,3 +2126,69 @@ bin_trait!(Add, add, |_| 1., |v| v);
|
||||
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
|
||||
bin_trait!(Mul, mul, |v| v, |_| 0.);
|
||||
bin_trait!(Div, div, |v| 1. / v, |_| 0.);
|
||||
|
||||
impl std::ops::Add<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn add(self, rhs: Tensor) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn add(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn mul(self, rhs: Tensor) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn mul(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn sub(self, rhs: Tensor) -> Self::Output {
|
||||
rhs.affine(-1., self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
fn sub(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs.affine(-1., self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Div<Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: Tensor) -> Self::Output {
|
||||
rhs.recip()? * self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Div<&Tensor> for f64 {
|
||||
type Output = Result<Tensor>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: &Tensor) -> Self::Output {
|
||||
rhs.recip()? * self
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user