From c60831aad4f266a320f9854f4ebb3d2d4ab8bc66 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 21 Jul 2023 07:52:39 +0200
Subject: [PATCH] Add more gradient tests + bugfixes. (#211)

* Add more gradient tests + bugfixes.

* More tests and fixes.

* More tests.
---
 candle-core/src/backprop.rs          |  5 +-
 candle-core/tests/grad_tests.rs      | 57 ++++++++++++++++++-
 .../examples/simple-training/main.rs |  2 +-
 3 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index cc4ffd49..4afaf23b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -146,7 +146,7 @@ impl Tensor {
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                         let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                     }
                     Op::WhereCond(pred, t, f) => {
                         let zeros = grad.zeros_like()?;
@@ -162,6 +162,7 @@
                         let dim = *dim;
                         let sum_grad = grads.or_insert(arg)?;
                         // TODO: This is very very very inefficient, have some dedicated kernel for this.
+                        // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
                         let indexes = indexes.to_vec1::<u32>()?;
                         for (dst_index, src_index) in indexes.iter().enumerate() {
                             let src_index = *src_index as usize;
@@ -318,7 +319,7 @@ impl Tensor {
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Unary(arg, UnaryOp::Sqrt) => {
-                        let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
+                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 0ceab1de..6f30b5b7 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Var};
+use candle::{Device, Shape, Tensor, Var};
 mod test_utils;
 
 fn simple_grad(device: &Device) -> Result<()> {
@@ -110,6 +110,61 @@ fn unary_grad(device: &Device) -> Result<()> {
         grad_x.to_vec1::<f32>()?,
         [806.8576, 14.778111, 5961.9155, 2.6997175]
     );
+    let y = x.sin()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [0.14112, 0.84147096, -0.7568025, 0.14943814],
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+    );
+    let y = x.cos()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.14112, -0.84147096, 0.7568025, -0.14943814],
+    );
+    let y = x.sqr()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [9.0, 1.0, 16.0, 0.0225]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, 8.0, 0.3]);
+    let y = x.sqr()?.sqrt()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
+    let y = x.neg()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [-3.0, -1.0, -4.0, -0.15]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [-1.0, -1.0, -1.0, -1.0]);
+    let y = x.affine(0.2, 1.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [1.6, 1.2, 1.8, 1.03]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [0.2, 0.2, 0.2, 0.2]);
+    let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.11111111, -1.0, -0.0625, -44.444443],
+    );
+    let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
     Ok(())
 }
 
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index bf7385ac..767266f6 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -40,7 +40,7 @@ pub fn main() -> Result<()> {
     let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
     let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
-    let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
+    let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_vec1::<u8>()?;
     for epoch in 1..200 {
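
Note (editorial, not part of the commit): the two backprop fixes follow
directly from the calculus. For z = x / y, the gradient with respect to
the denominator is dz/dy = -x / y^2, so the rhs gradient must be
subtracted from the accumulator rather than added. For z = sqrt(x),
dz/dx = 1 / (2 * sqrt(x)) = 1 / (2 * z), so the incoming gradient should
be divided by the op's output (`node`), not its input (`arg`). Below is a
minimal, dependency-free Rust sketch that double-checks both formulas
against central finite differences at the same points the new tests use
([3., 1., 4., 0.15]); the helper `numeric_grad` and the program itself
are illustrative, not candle API:

    // Central-difference approximation of df/dx at x.
    fn numeric_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
        let h = 1e-6;
        (f(x + h) - f(x - h)) / (2.0 * h)
    }

    fn main() {
        for &x in &[3.0f64, 1.0, 4.0, 0.15] {
            // Div fix: z = 1 / x has dz/dx = -1 / x^2 (note the sign).
            let analytic = -1.0 / (x * x);
            assert!((analytic - numeric_grad(|v| 1.0 / v, x)).abs() < 1e-4);

            // Sqrt fix: z = sqrt(x) has dz/dx = 1 / (2 * sqrt(x)) = 1 / (2 * z).
            let analytic = 1.0 / (2.0 * x.sqrt());
            assert!((analytic - numeric_grad(f64::sqrt, x)).abs() < 1e-4);
        }
        println!("both gradient formulas match finite differences");
    }

At x = 0.15 the div gradient is -1 / 0.15^2 ≈ -44.444443, which is
exactly the value asserted in the new broadcast_div test above.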