From 23db8a7da87ab4c42d86728d3129b6f5a345b6b9 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 21 Jun 2023 07:33:42 +0100
Subject: [PATCH] Simplify the test.

---
 tests/grad_tests.rs | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
index 4d7eae7d..079b7952 100644
--- a/tests/grad_tests.rs
+++ b/tests/grad_tests.rs
@@ -4,13 +4,12 @@ use candle::{Device, Tensor};
 #[test]
 fn simple_grad() -> Result<()> {
     let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
-    let five = Tensor::new(&[5f32, 5., 5.], Device::Cpu)?;
-    let y = x.mul(&x)?.add(&x.mul(&five)?)?.add(&five)?;
+    let y = x.mul(&x)?.add(&x.affine(5., 0.)?)?.affine(1., 4.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(&x.id()).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
-    // y = x^2 + 5.x + 5
-    assert_eq!(y.to_vec1::<f32>()?, [29., 11., 41.]);
+    // y = x^2 + 5.x + 4
+    assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
     // dy/dx = 2.x + 5
     assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
     Ok(())
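
Note on the expected values: a minimal standalone sketch, assuming candle's `affine(mul, add)` computes `mul * x + add` elementwise, so the simplified expression still builds y = x^2 + 5x + 4 and only the constant term changes from 5 to 4. This is plain Rust for checking the arithmetic in the asserts, not candle code.

    // Sanity check of the updated test values, assuming affine(mul, add) = mul * x + add.
    fn main() {
        let xs = [3f32, 1., 4.];
        // y = x^2 + 5x + 4 at x = 3, 1, 4 gives 28, 10, 40.
        let y: Vec<f32> = xs.iter().map(|&x| x * x + 5. * x + 4.).collect();
        // dy/dx = 2x + 5 is unaffected by the constant term: 11, 7, 13.
        let dy_dx: Vec<f32> = xs.iter().map(|&x| 2. * x + 5.).collect();
        assert_eq!(y, [28., 10., 40.]);    // matches the new assert on y
        assert_eq!(dy_dx, [11., 7., 13.]); // gradient assertion stays unchanged
        println!("y = {y:?}, dy/dx = {dy_dx:?}");
    }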