Support backprop for a few more ops. (#254)

This commit is contained in:
Laurent Mazare
2023-07-26 21:31:54 +01:00
committed by GitHub
parent 4f92420132
commit 89ba005962

View File

@ -169,8 +169,22 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
} }
Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?, Op::ScatterAdd(init, indexes, src, dim) => {
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?, let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
let src_grad = grad.index_select(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexSelect(arg, indexes, dim) => { Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?; *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
@ -228,7 +242,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?; *sum_grad = sum_grad.add(&grad)?;
} }
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }), Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?; let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?;
@ -268,7 +282,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)? *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
} }
Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?, Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;
let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
*sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
}
Op::Unary(arg, UnaryOp::Exp) => { Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)? *sum_grad = sum_grad.add(&(&grad * *node)?)?
@ -303,12 +322,8 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Reduce(_, ReduceOp::ArgMin, _) => { Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Err(Error::BackwardNotSupported { op: "argmin" })? Op::Reduce(_, ReduceOp::ArgMax, _) => {}
}
Op::Reduce(_, ReduceOp::ArgMax, _) => {
Err(Error::BackwardNotSupported { op: "argmax" })?
}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?, Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => { Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?; let arg_grad = grad.reshape(arg.dims())?;
@ -316,7 +331,11 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?, Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?, Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => { Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? { if let Some(arg_grad) = c.bwd(arg, node, &grad)? {