More realistic training setup. (#210)

* More realistic training setup.

* Compute the model accuracy.

* Very inefficient backprop for index select.

* More backprop.

* Fix some backprop issues.

* Backprop fix.

* Another broadcasting backprop fix.

* Better backprop for reducing ops.

* Training again.

* Add some gradient tests.

* Get the training to work.
Laurent Mazare
2023-07-20 19:25:41 +02:00
committed by GitHub
parent fa08fb3126
commit 4845d5cc64
6 changed files with 156 additions and 37 deletions

View File

@@ -2,6 +2,19 @@ use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
+// arg has been reduced to node via reduce_dims, expand it back to arg.
+// This has to handle keepdims.
+fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
+    if arg.rank() == node.rank() {
+        // keepdim = true
+        node.broadcast_as(arg.shape())
+    } else {
+        // keepdim = false
+        // first expand the reduced dims.
+        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
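
Note on the new `broadcast_back` helper above: it undoes a reduction when propagating gradients. If `arg` and `node` have the same rank, the reduction kept its dims (keepdim = true) and a plain broadcast suffices; otherwise the keepdim shape is restored with a reshape first. Below is a minimal standalone sketch of the keepdim = false path, assuming the crate is imported as `candle` (adjust the import to your setup); summing [[1, 2, 3], [4, 5, 6]] over dim 1 gives [6, 15]:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // `node`, shape [2]: the dim-1 sum of a [2, 3] `arg`, computed without keepdim.
    let node = Tensor::new(&[6f32, 15.], &dev)?;
    // `reduced_dims` stores the keepdim=true shape, here [2, 1]:
    // reshape to it, then broadcast back up to arg's [2, 3] shape.
    let expanded = node.reshape((2, 1))?.broadcast_as((2, 3))?;
    assert_eq!(expanded.to_vec2::<f32>()?, [[6., 6., 6.], [15., 15., 15.]]);
    Ok(())
}
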
@@ -145,8 +158,26 @@ impl Tensor {
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-                    Op::IndexSelect(_lhs, _rhs, _) => {
-                        Err(Error::BackwardNotSupported { op: "index-select" })?
+                    Op::IndexSelect(arg, indexes, dim) => {
+                        let dim = *dim;
+                        let sum_grad = grads.or_insert(arg)?;
+                        // TODO: This is very very very inefficient, have some dedicated kernel for this.
+                        let indexes = indexes.to_vec1::<u32>()?;
+                        for (dst_index, src_index) in indexes.iter().enumerate() {
+                            let src_index = *src_index as usize;
+                            let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
+                            let mut pre_dims = arg.dims().to_vec();
+                            pre_dims[dim] = src_index;
+                            let pre_zeros =
+                                Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
+                            let mut post_dims = arg.dims().to_vec();
+                            post_dims[dim] = post_dims[dim] - src_index - 1;
+                            let post_zeros =
+                                Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
+                            let src_grad =
+                                Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
+                            *sum_grad = sum_grad.add(&src_grad)?;
+                        }
                     }
                     Op::Embedding(_lhs, _rhs) => {
                         Err(Error::BackwardNotSupported { op: "embedding" })?
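
Note on the new `IndexSelect` backward pass: each selected slice of the output gradient is scatter-added back to its source position by sandwiching it between two zero tensors and concatenating along `dim`, which is why the TODO asks for a dedicated kernel. A standalone sketch of that zero-padding trick for the 2D, dim = 0 case (`scatter_row` is a hypothetical name; `zeros`/`cat` are used as in the diff above):

use candle::{DType, Device, Result, Tensor};

// Place a [1, n] gradient row at row `src_index` of an otherwise zero
// [rows, n] tensor by concatenating zero blocks before and after it.
fn scatter_row(grad_row: &Tensor, src_index: usize, rows: usize, n: usize) -> Result<Tensor> {
    let pre_zeros = Tensor::zeros((src_index, n), DType::F32, grad_row.device())?;
    let post_zeros = Tensor::zeros((rows - src_index - 1, n), DType::F32, grad_row.device())?;
    // [src_index, n] ++ [1, n] ++ [rows - src_index - 1, n] -> [rows, n]
    Tensor::cat(&[&pre_zeros, grad_row, &post_zeros], 0)
}

Because every iteration adds into `sum_grad`, indices that occur several times correctly accumulate their gradients; the cost is one full-size temporary per selected index.
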
@@ -189,20 +220,32 @@ impl Tensor {
                             }
                         }
-                        let arg_grad = grad.sum(sum_dims.as_slice())?;
+                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
+                        for _i in 0..left_dims {
+                            arg_grad = arg_grad.squeeze(0)?
+                        }
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.broadcast_add(&arg_grad)?
+                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                     }
-                    Op::Reduce(arg, ReduceOp::Sum, _) => {
+                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
+                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.broadcast_add(&grad)?
+                        *sum_grad = sum_grad.add(&grad)?;
                     }
                     Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
-                    Op::Reduce(_args, ReduceOp::Max, _) => {
-                        Err(Error::BackwardNotSupported { op: "max" })?
+                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
+                        let node = broadcast_back(arg, node, reduced_dims)?;
+                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                     }
-                    Op::Reduce(_args, ReduceOp::Min, _) => {
-                        Err(Error::BackwardNotSupported { op: "min" })?
+                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
+                        let node = broadcast_back(arg, node, reduced_dims)?;
+                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                     }
                     Op::ToDType(arg) => {
                         let sum_grad = grads.or_insert(arg)?;
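
Note on the `Max`/`Min` arms above: the reduced values (`node`) are broadcast back to the argument's shape and compared with `eq`, yielding a 0/1 mask that routes the incoming gradient only to positions that attained the extremum. A minimal sketch of the masking step (`max_mask_demo` is a hypothetical name; `broadcast_as`/`eq`/`to_dtype` as used in the diff):

use candle::{DType, Device, Result, Tensor};

fn max_mask_demo() -> Result<()> {
    let arg = Tensor::new(&[1f32, 3., 2.], &Device::Cpu)?;
    let node = Tensor::new(&[3f32], &Device::Cpu)?; // max over dim 0 with keepdim, shape [1]
    // mask = (broadcast(node) == arg): 1.0 at the argmax, 0.0 elsewhere.
    let mask = node.broadcast_as(arg.shape())?.eq(&arg)?.to_dtype(DType::F32)?;
    assert_eq!(mask.to_vec1::<f32>()?, [0., 1., 0.]);
    Ok(())
}

One consequence of this formulation: when several elements tie for the extremum, each of them receives the full incoming gradient rather than a share of it.
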
@@ -215,7 +258,7 @@ impl Tensor {
                     }
                     Op::Unary(arg, UnaryOp::Log) => {
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
+                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                     }
                     Op::Unary(arg, UnaryOp::Sin) => {
                         let sum_grad = grads.or_insert(arg)?;
@@ -228,7 +271,7 @@ impl Tensor {
                     Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
                     Op::Unary(arg, UnaryOp::Exp) => {
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&(&grad / arg)?)?
+                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                     }
                     Op::Unary(arg, UnaryOp::Neg) => {
                         let sum_grad = grads.or_insert(arg)?;
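
Note on the two `Unary` changes above: the log and exp gradients had been swapped. For y = log(x), dy/dx = 1/x, hence `grad / arg`; for y = exp(x), dy/dx = exp(x), which is the forward output itself, hence `grad * node`. With x = [3, 1, 4, 0.15] as in the new `unary_grad` test below, the log gradient is [1/3, 1, 1/4, 1/0.15] ≈ [0.33333334, 1.0, 0.25, 6.6666665], and the exp gradient equals the forward values exactly.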

View File

@@ -52,7 +52,7 @@ mod mkl;
 pub mod npy;
 mod op;
 pub mod safetensors;
-mod shape;
+pub mod shape;
 mod storage;
 mod strided_index;
 mod tensor;

View File

@@ -48,6 +48,7 @@ pub(crate) enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
+    // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
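
To make the new comment concrete: the `Vec<usize>` is not the list of reduced dimension indices but the full output shape with each reduced dim clamped to 1. A small sketch mirroring how tensor.rs (below) builds it, with illustrative sizes:

// For a [2, 3, 4] tensor reduced over dims 0 and 2:
let mut dims = vec![2usize, 3, 4];
for &d in [0usize, 2].iter() {
    dims[d] = 1; // each reduced dim becomes 1
}
assert_eq!(dims, vec![1, 3, 1]); // the keepdim=true shape stored in Op::Reduce
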

View File

@@ -633,15 +633,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &max_dim in max_dims.iter() {
             dims[max_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
+        } else {
+            None
+        };
         let max = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(max)
@@ -655,15 +655,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &min_dim in min_dims.iter() {
             dims[min_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
+        } else {
+            None
+        };
         let min = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(min)
@@ -677,15 +677,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
+        } else {
+            None
+        };
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
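
The key change in all three hunks above is the same: the `op` is now built after `dims` has been clamped, so `Op::Reduce` records the keepdim = true shape instead of the raw list of reduced dim indices (`max_dims`/`min_dims`/`sum_dims`), which is what `broadcast_back` consumes. A hypothetical shape check (`shapes_demo` is an invented name; `sum_keepdim` appears in the backprop.rs hunk, so both entry points are assumed available):

use candle::{DType, Device, Result, Tensor};

fn shapes_demo() -> Result<()> {
    let x = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(x.sum_keepdim(&[1])?.dims(), &[2, 1]); // keepdim: [2, 3] -> [2, 1]
    assert_eq!(x.sum(&[1])?.dims(), &[2]); // squeezed: [2, 3] -> [2]
    Ok(())
}
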

View File

@@ -79,7 +79,42 @@ fn grad_descent(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn unary_grad(device: &Device) -> Result<()> {
+    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
+    let x = x.as_tensor();
+    let y = (x.log()? + 1.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    let y = x.exp()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    let y = x.exp()?.sqr()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [403.4288, 7.3890557, 2980.9578, 1.3498588]
+    );
+    // exp(x)^2 = exp(2*x)
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [806.8576, 14.778111, 5961.9155, 2.6997175]
+    );
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
 test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
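
The last block of `unary_grad` doubles as a chain-rule check: exp(x)^2 = exp(2x), so its derivative is 2*exp(2x), i.e. twice the forward values; the expected gradient [806.8576, 14.778111, 5961.9155, 2.6997175] is indeed 2 x [403.4288, 7.3890557, 2980.9578, 1.3498588], up to float rounding.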