From f319583530745dfab125bd2d16c2dfa4aa75646d Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 21 Jun 2023 08:59:08 +0100 Subject: [PATCH] More QOL changes, binary op for constants. --- src/tensor.rs | 26 +++++++++++++++++++++----- tests/grad_tests.rs | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index 816308e0..b8fa738a 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -445,7 +445,7 @@ impl Tensor { } macro_rules! bin_trait { - ($trait:ident, $fn1:ident) => { + ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => { impl> std::ops::$trait for Tensor { type Output = Result; @@ -477,10 +477,26 @@ macro_rules! bin_trait { Tensor::$fn1(&self, rhs?.borrow()) } } + + impl std::ops::$trait for Tensor { + type Output = Result; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } + + impl std::ops::$trait for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } }; } -bin_trait!(Add, add); -bin_trait!(Sub, sub); -bin_trait!(Mul, mul); -bin_trait!(Div, div); +bin_trait!(Add, add, |_| 1., |v| v); +bin_trait!(Sub, sub, |_| 1., |v: f64| -v); +bin_trait!(Mul, mul, |v| v, |_| 0.); +bin_trait!(Div, div, |v| 1. / v, |_| 0.); diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs index 079b7952..e5ba68e8 100644 --- a/tests/grad_tests.rs +++ b/tests/grad_tests.rs @@ -4,7 +4,7 @@ use candle::{Device, Tensor}; #[test] fn simple_grad() -> Result<()> { let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?; - let y = x.mul(&x)?.add(&x.affine(5., 0.)?)?.affine(1., 4.)?; + let y = (((&x * &x)? + &x * 5f64)? + 4f64)?; let grads = y.backward()?; let grad_x = grads.get(&x.id()).context("no grad for x")?; assert_eq!(x.to_vec1::()?, [3., 1., 4.]);