mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum. * And add the gradient for the broadcast ops. * Add some backprop tests. * Add some linear regression example.
This commit is contained in:
@ -16,6 +16,26 @@ fn simple_grad(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sum_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
let x = x.as_tensor();
|
||||
let y = (x.sqr()?.sum(&[0])? * 2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [52.]);
|
||||
// y = 2.x^2 so dy/dx = 4.x
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
|
||||
|
||||
// Same test as before but squeezing on the last dimension.
|
||||
let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_scalar::<f32>()?, 52.);
|
||||
// y = 2.x^2 so dy/dx = 4.x
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matmul_grad(device: &Device) -> Result<()> {
|
||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||
let x = Var::from_slice(&data, (2, 2, 3), device)?;
|
||||
@ -60,5 +80,6 @@ fn grad_descent(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
|
||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
||||
|
Reference in New Issue
Block a user