Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
More realistic training setup. (#210)
* More realistic training setup.
* Compute the model accuracy.
* Very inefficient backprop for index select.
* More backprop.
* Fix some backprop issues.
* Backprop fix.
* Another broadcasting backprop fix.
* Better backprop for reducing ops.
* Training again.
* Add some gradient tests.
* Get the training to work.
@@ -2,6 +2,19 @@ use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
+// arg has been reduced to node via reduce_dims, expand it back to arg.
+// This has to handle keepdims.
+fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
+    if arg.rank() == node.rank() {
+        // keepdim = true
+        node.broadcast_as(arg.shape())
+    } else {
+        // keepdim = false
+        // first expand the reduced dims.
+        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
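Aside (not part of the commit): broadcast_back undoes a reduction for gradient purposes, reshaping the reduced value back to a shape that broadcasts against the original argument. A minimal sketch of the same idea, assuming the present-day candle_core API rather than the exact snapshot in this diff:

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Assumes the present-day candle_core API (Tensor::new, sum, reshape, broadcast_as).
        let dev = Device::Cpu;
        // arg has shape (2, 3); summing over dim 1 without keepdim gives shape (2,).
        let arg = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
        let node = arg.sum(1)?;
        // To push a gradient back through the sum, reshape to (2, 1) (the shape with the
        // reduced dim kept) and broadcast back to (2, 3): every input element of a sum
        // receives the full output gradient.
        let expanded = node.reshape((2, 1))?.broadcast_as(arg.shape())?;
        assert_eq!(expanded.dims(), arg.dims());
        println!("{expanded}");
        Ok(())
    }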
@@ -145,8 +158,26 @@ impl Tensor {
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-                Op::IndexSelect(_lhs, _rhs, _) => {
-                    Err(Error::BackwardNotSupported { op: "index-select" })?
+                Op::IndexSelect(arg, indexes, dim) => {
+                    let dim = *dim;
+                    let sum_grad = grads.or_insert(arg)?;
+                    // TODO: This is very very very inefficient, have some dedicated kernel for this.
+                    let indexes = indexes.to_vec1::<u32>()?;
+                    for (dst_index, src_index) in indexes.iter().enumerate() {
+                        let src_index = *src_index as usize;
+                        let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
+                        let mut pre_dims = arg.dims().to_vec();
+                        pre_dims[dim] = src_index;
+                        let pre_zeros =
+                            Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
+                        let mut post_dims = arg.dims().to_vec();
+                        post_dims[dim] = post_dims[dim] - src_index - 1;
+                        let post_zeros =
+                            Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
+                        let src_grad =
+                            Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
+                        *sum_grad = sum_grad.add(&src_grad)?;
+                    }
                 }
                 Op::Embedding(_lhs, _rhs) => {
                     Err(Error::BackwardNotSupported { op: "embedding" })?
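Aside (not part of the commit): the rule implemented above is a scatter-add. If y = x.index_select(ids, dim), then the gradient of x at index ids[i] accumulates the gradient of y at index i, and repeated indices accumulate. The zeros/cat construction above builds exactly that, one selected slice at a time. A self-contained plain-Rust illustration of the rule in 1D, using a hypothetical helper rather than candle code:

    // Hypothetical 1D illustration of the index-select backward pass (scatter-add).
    fn index_select_backward_1d(grad_out: &[f32], ids: &[usize], input_len: usize) -> Vec<f32> {
        let mut grad_in = vec![0f32; input_len];
        for (dst_index, &src_index) in ids.iter().enumerate() {
            // The gradient at output position dst_index flows back to input position
            // src_index; duplicate indices sum their contributions.
            grad_in[src_index] += grad_out[dst_index];
        }
        grad_in
    }

    fn main() {
        // Forward selected indices [2, 0, 2] from a length-4 input.
        let grad = index_select_backward_1d(&[1.0, 10.0, 100.0], &[2, 0, 2], 4);
        assert_eq!(grad, vec![10.0, 0.0, 101.0, 0.0]);
        println!("{grad:?}");
    }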
@@ -189,20 +220,32 @@ impl Tensor {
                         }
                     }
 
-                    let arg_grad = grad.sum(sum_dims.as_slice())?;
+                    let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
+                    for _i in 0..left_dims {
+                        arg_grad = arg_grad.squeeze(0)?
+                    }
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.broadcast_add(&arg_grad)?
+                    *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                 }
-                Op::Reduce(arg, ReduceOp::Sum, _) => {
+                Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.broadcast_add(&grad)?
+                    *sum_grad = sum_grad.add(&grad)?;
                 }
                 Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
-                Op::Reduce(_args, ReduceOp::Max, _) => {
-                    Err(Error::BackwardNotSupported { op: "max" })?
+                Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
+                    let node = broadcast_back(arg, node, reduced_dims)?;
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                    let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                 }
-                Op::Reduce(_args, ReduceOp::Min, _) => {
-                    Err(Error::BackwardNotSupported { op: "min" })?
+                Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
+                    let node = broadcast_back(arg, node, reduced_dims)?;
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                    let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
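Aside (not part of the commit): for max and min the gradient is routed only to the elements that achieved the extremum, via the eq mask built above; with ties, every extremal element receives the gradient under this scheme. A minimal sketch of the same rule using broadcasting comparison ops, assuming the present-day candle_core API:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Assumes the present-day candle_core API (max_keepdim, broadcast_eq, broadcast_mul).
        let dev = Device::Cpu;
        let x = Tensor::new(&[[1f32, 5., 5.], [2., 0., 7.]], &dev)?;
        // Forward: per-row max, keeping the reduced dim so shapes broadcast back.
        let m = x.max_keepdim(1)?; // shape (2, 1)
        // Backward: mask of positions equal to the max, times the incoming gradient.
        let grad_out = Tensor::ones((2, 1), DType::F32, &dev)?;
        let mask = x.broadcast_eq(&m)?.to_dtype(DType::F32)?;
        let grad_in = mask.broadcast_mul(&grad_out)?;
        // Row 0 has a tie at 5., so both tied positions receive gradient 1.
        println!("{grad_in}");
        Ok(())
    }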
@@ -215,7 +258,7 @@ impl Tensor {
                 }
                 Op::Unary(arg, UnaryOp::Log) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
+                    *sum_grad = sum_grad.add(&(grad / arg)?)?
                 }
                 Op::Unary(arg, UnaryOp::Sin) => {
                     let sum_grad = grads.or_insert(arg)?;
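Aside (not part of the commit): this change fixes the log gradient to d/dx log(x) = 1/x (the old line used the exp rule). A minimal check with tracked variables, assuming the present-day candle_core Var/backward API:

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        // Assumes the present-day candle_core Var / backward / GradStore API.
        let dev = Device::Cpu;
        let x = Var::new(&[0.5f32, 1.0, 2.0], &dev)?;
        // y = sum(log(x)); dy/dx = 1 / x, i.e. [2.0, 1.0, 0.5].
        let y = x.log()?.sum_all()?;
        let grads = y.backward()?;
        let gx = grads.get(&x).expect("no gradient for x");
        println!("{gx}");
        Ok(())
    }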
@@ -228,7 +271,7 @@ impl Tensor {
                 Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
                 Op::Unary(arg, UnaryOp::Exp) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&(&grad / arg)?)?
+                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
                 }
                 Op::Unary(arg, UnaryOp::Neg) => {
                     let sum_grad = grads.or_insert(arg)?;
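Aside (not part of the commit): the mirror-image fix, since d/dx exp(x) = exp(x), i.e. the incoming gradient times the node's own output. The kind of gradient test the commit message mentions can be sketched as an analytic-versus-finite-difference comparison, again assuming the present-day candle_core API:

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        // Assumes the present-day candle_core Var / backward / GradStore API.
        let dev = Device::Cpu;
        let x = Var::new(&[0.1f32, 1.0, 2.0], &dev)?;
        let y = x.exp()?.sum_all()?;
        let analytic = y.backward()?.get(&x).expect("no gradient").to_vec1::<f32>()?;

        // Central finite difference of sum(exp(x)) with respect to each coordinate.
        let eps = 1e-3f32;
        let base = x.to_vec1::<f32>()?;
        for (i, (a, v)) in analytic.iter().zip(base.iter()).enumerate() {
            let fd = ((v + eps).exp() - (v - eps).exp()) / (2.0 * eps);
            println!("coord {i}: analytic {a:.4} vs finite difference {fd:.4}");
        }
        Ok(())
    }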