mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add more gradient tests + bugfixes. (#211)

* Add more gradient tests + bugfixes.
* More tests and fixes.
* More tests.
@@ -146,7 +146,7 @@ impl Tensor {
             *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
             let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
             let rhs_sum_grad = grads.or_insert(rhs)?;
-            *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+            *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
         }
         Op::WhereCond(pred, t, f) => {
             let zeros = grad.zeros_like()?;
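Why this line flips from add to sub: for z = lhs / rhs, the gradient with respect to rhs is -lhs / rhs^2, and the rhs_grad computed above is grad * lhs / rhs^2 with no minus sign, so it has to be subtracted from the accumulator. A minimal sketch of a regression check, using the same API as the tests in this commit (Var::new is assumed to mirror Tensor::new; the expected values are the ones from the broadcast_div test added below):

    use anyhow::{Context, Result};
    use candle::{Device, Tensor, Var};

    fn div_rhs_grad(device: &Device) -> Result<()> {
        // x plays the rhs role: y = 1 / x, so dy/dx = -1 / x^2.
        let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
        let y = Tensor::new(1f32, device)?.broadcast_div(&x)?;
        let grads = y.backward()?;
        let grad_x = grads.get(&x).context("no grad for x")?;
        // With the old `add`, these would come out with the wrong sign.
        assert_eq!(
            grad_x.to_vec1::<f32>()?,
            [-0.11111111, -1.0, -0.0625, -44.444443],
        );
        Ok(())
    }
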
@@ -162,6 +162,7 @@ impl Tensor {
             let dim = *dim;
             let sum_grad = grads.or_insert(arg)?;
+            // TODO: This is very very very inefficient, have some dedicated kernel for this.
             // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
             let indexes = indexes.to_vec1::<u32>()?;
             for (dst_index, src_index) in indexes.iter().enumerate() {
                 let src_index = *src_index as usize;
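The loop above is a host-side scatter-add: each slice of the incoming gradient is accumulated into the argument's gradient at the row named by indexes, which is what torch.Tensor.index_add does in a single kernel (hence the TODO). An illustrative plain-Rust sketch of that semantics for the 2-D, dim-0 case (hypothetical helper, not candle's actual code):

    /// dst[indexes[i]] += src[i] for every i, i.e. an index_add along dim 0.
    fn index_add_rows(dst: &mut [Vec<f32>], indexes: &[u32], src: &[Vec<f32>]) {
        for (dst_index, src_index) in indexes.iter().enumerate() {
            let src_index = *src_index as usize;
            // `dst_index` walks the incoming gradient; `src_index` is where
            // that row accumulates in the argument's gradient.
            for (d, s) in dst[src_index].iter_mut().zip(&src[dst_index]) {
                *d += *s;
            }
        }
    }
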
@@ -318,7 +319,7 @@ impl Tensor {
             *sum_grad = sum_grad.add(&arg_grad)?
         }
         Op::Unary(arg, UnaryOp::Sqrt) => {
-            let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
+            let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.add(&arg_grad)?
         }
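The sqrt fix: for y = sqrt(x), dy/dx = 1 / (2 sqrt(x)) = 1 / (2y), so the incoming gradient must be divided by the op's output (node), not its input (arg); the old grad.div(arg) computed grad / (2x) instead. A hypothetical spot-check in the style of the tests below, comparing against 0.5 / sqrt(x) with a tolerance rather than exact f32 literals:

    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
    let y = x.sqrt()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    for (g, xi) in grad_x.to_vec1::<f32>()?.iter().zip([3f32, 1., 4., 0.15]) {
        assert!((g - 0.5 / xi.sqrt()).abs() < 1e-5);
    }
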
@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Var};
+use candle::{Device, Shape, Tensor, Var};
 mod test_utils;

 fn simple_grad(device: &Device) -> Result<()> {
@@ -110,6 +110,61 @@ fn unary_grad(device: &Device) -> Result<()> {
         grad_x.to_vec1::<f32>()?,
         [806.8576, 14.778111, 5961.9155, 2.6997175]
     );
+    let y = x.sin()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [0.14112, 0.84147096, -0.7568025, 0.14943814],
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+    );
+    let y = x.cos()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.14112, -0.84147096, 0.7568025, -0.14943814],
+    );
+    let y = x.sqr()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [9.0, 1.0, 16.0, 0.0225]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, 8.0, 0.3]);
+    let y = x.sqr()?.sqrt()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
+    let y = x.neg()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [-3.0, -1.0, -4.0, -0.15]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [-1.0, -1.0, -1.0, -1.0]);
+    let y = x.affine(0.2, 1.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [1.6, 1.2, 1.8, 1.03]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [0.2, 0.2, 0.2, 0.2]);
+    let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [-0.11111111, -1.0, -0.0625, -44.444443],
+    );
+    let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
     Ok(())
 }
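Two things these tests rely on: backward() is called on a non-scalar y, and the resulting grad_x values equal the plain elementwise derivatives (cos x for sin, -sin x for cos, 2x for sqr), so the backward pass is seeded with a gradient of ones. The x.sqr()?.sqrt()? case is the regression test for the sqrt fix above: by the chain rule, d/dx sqrt(x^2) = 2x / (2 sqrt(x^2)) = x / |x| = 1 for these positive inputs, whereas the old grad.div(arg) formula would have yielded 1 / x, i.e. roughly [0.333, 1.0, 0.25, 6.667].
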
@@ -40,7 +40,7 @@ pub fn main() -> Result<()> {
     let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
     let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
-    let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
+    let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_vec1::<u8>()?;
     for epoch in 1..200 {
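The only change in this hunk is the SGD learning rate, 3e-1 -> 1.0. For reference, the update an SGD step applies to each tracked Var is the vanilla rule below (illustrative plain Rust over f32 slices, not candle_nn's implementation):

    /// w <- w - lr * grad, applied to every tracked parameter.
    fn sgd_step(w: &mut [f32], grad: &[f32], lr: f32) {
        for (w, g) in w.iter_mut().zip(grad) {
            *w -= lr * g;
        }
    }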