Add the gradient for reduce-sum. (#162)

* Add the gradient for reduce-sum.

* And add the gradient for the broadcast ops.

* Add some backprop tests.

* Add some linear regression example.
This commit is contained in:
Laurent Mazare
2023-07-13 20:14:10 +01:00
committed by GitHub
parent 3c02ea56b0
commit 23e105cd94
4 changed files with 74 additions and 7 deletions

View File

@ -179,11 +179,33 @@ impl Tensor {
start_idx += len;
}
}
Op::Broadcast(_arg) => {
return Err(Error::BackwardNotSupported { op: "broadcast" })
Op::Broadcast(arg) => {
let arg_dims = arg.dims();
let node_dims = node.dims();
// The number of dims that have been inserted on the left.
let left_dims = node_dims.len() - arg_dims.len();
let mut sum_dims: Vec<usize> = (0..left_dims).collect();
for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
.iter()
.zip(arg_dims.iter())
.enumerate()
{
if node_dim != arg_dim {
sum_dims.push(dim + left_dims)
}
}
let mut arg_grad = grad.sum(sum_dims.as_slice())?;
// sum_dims has increasing values.
for &dim in sum_dims.iter().rev() {
arg_grad = arg_grad.squeeze(dim)?
}
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
Op::Sum(_arg, _sum_dims) => {
return Err(Error::BackwardNotSupported { op: "sum" })
Op::Sum(arg, _sum_dims) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&grad)?
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;

View File

@ -16,6 +16,26 @@ fn simple_grad(device: &Device) -> Result<()> {
Ok(())
}
fn sum_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
let x = x.as_tensor();
let y = (x.sqr()?.sum(&[0])? * 2.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [52.]);
// y = 2.x^2 so dy/dx = 4.x
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
// Same test as before but squeezing on the last dimension.
let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_scalar::<f32>()?, 52.);
// y = 2.x^2 so dy/dx = 4.x
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
Ok(())
}
fn matmul_grad(device: &Device) -> Result<()> {
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let x = Var::from_slice(&data, (2, 2, 3), device)?;
@ -60,5 +80,6 @@ fn grad_descent(device: &Device) -> Result<()> {
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);