Add more gradient tests + bugfixes. (#211)

* Add more gradient tests + bugfixes.

* More tests and fixes.

* More tests.
This commit is contained in:
Laurent Mazare
2023-07-21 07:52:39 +02:00
committed by GitHub
parent 4845d5cc64
commit c60831aad4
3 changed files with 60 additions and 4 deletions

View File

@ -1,5 +1,5 @@
use anyhow::{Context, Result};
use candle::{Device, Shape, Var};
use candle::{Device, Shape, Tensor, Var};
mod test_utils;
fn simple_grad(device: &Device) -> Result<()> {
@ -110,6 +110,61 @@ fn unary_grad(device: &Device) -> Result<()> {
grad_x.to_vec1::<f32>()?,
[806.8576, 14.778111, 5961.9155, 2.6997175]
);
let y = x.sin()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
y.to_vec1::<f32>()?,
[0.14112, 0.84147096, -0.7568025, 0.14943814],
);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[-0.9899925, 0.5403023, -0.6536436, 0.9887711],
);
let y = x.cos()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
y.to_vec1::<f32>()?,
[-0.9899925, 0.5403023, -0.6536436, 0.9887711],
);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[-0.14112, -0.84147096, 0.7568025, -0.14943814],
);
let y = x.sqr()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [9.0, 1.0, 16.0, 0.0225]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, 8.0, 0.3]);
let y = x.sqr()?.sqrt()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
let y = x.neg()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [-3.0, -1.0, -4.0, -0.15]);
assert_eq!(grad_x.to_vec1::<f32>()?, [-1.0, -1.0, -1.0, -1.0]);
let y = x.affine(0.2, 1.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [1.6, 1.2, 1.8, 1.03]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.2, 0.2, 0.2, 0.2]);
let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[-0.11111111, -1.0, -0.0625, -44.444443],
);
let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
Ok(())
}