Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Support backprop for a few more ops. (#254)
@@ -169,8 +169,22 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                     }
-                    Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
-                    Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
+                    Op::ScatterAdd(init, indexes, src, dim) => {
+                        let init_sum_grad = grads.or_insert(init)?;
+                        *init_sum_grad = init_sum_grad.add(&grad)?;
+
+                        let src_grad = grad.gather(indexes, *dim)?;
+                        let src_sum_grad = grads.or_insert(src)?;
+                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                    }
+                    Op::IndexAdd(init, indexes, src, dim) => {
+                        let init_sum_grad = grads.or_insert(init)?;
+                        *init_sum_grad = init_sum_grad.add(&grad)?;
+
+                        let src_grad = grad.index_select(indexes, *dim)?;
+                        let src_sum_grad = grads.or_insert(src)?;
+                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                    }
                     Op::IndexSelect(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
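The two new arms route the incoming gradient to both inputs of the op: the accumulator `init` receives it unchanged (the forward pass is `init` plus the scattered `src`), while `src` receives it read back through the same index map, via `gather` for `scatter_add` and via `index_select` for `index_add`. A minimal standalone 1-D sketch of that routing (plain Rust with illustrative names, not the candle API):

```rust
// out[indexes[i]] += src[i]: the forward rule for a 1-D scatter-add.
fn scatter_add_1d(init: &[f32], indexes: &[usize], src: &[f32]) -> Vec<f32> {
    let mut out = init.to_vec();
    for (i, &idx) in indexes.iter().enumerate() {
        out[idx] += src[i];
    }
    out
}

// Backward: init receives grad_out as-is, src[i] receives grad_out[indexes[i]],
// i.e. a gather along the same indexes (index_add routes its src gradient the
// same way, using an index_select in the multi-dimensional case).
fn scatter_add_1d_backward(grad_out: &[f32], indexes: &[usize]) -> (Vec<f32>, Vec<f32>) {
    let grad_init = grad_out.to_vec();
    let grad_src = indexes.iter().map(|&idx| grad_out[idx]).collect();
    (grad_init, grad_src)
}

fn main() {
    let init = [0.0f32; 4];
    let indexes = [1usize, 3, 1];
    let src = [10.0f32, 20.0, 30.0];
    println!("out = {:?}", scatter_add_1d(&init, &indexes, &src));
    let (gi, gs) = scatter_add_1d_backward(&[1.0, 2.0, 3.0, 4.0], &indexes);
    println!("grad init = {gi:?}, grad src = {gs:?}"); // grad src = [2.0, 4.0, 2.0]
}
```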
@@ -228,7 +242,7 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&grad)?;
                     }
-                    Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+                    Op::Cmp(_args, _) => {}
                     Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                         let node = broadcast_back(arg, node, reduced_dims)?;
                         let grad = broadcast_back(arg, &grad, reduced_dims)?;
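Comparison results are piecewise constant, so their true gradient with respect to the inputs is zero almost everywhere; the arm now registers nothing (an absent entry in the grad store acts as a zero gradient) instead of aborting the whole backward pass. A hedged sketch of the kind of graph this unblocks, assuming a recent candle version where `Var` and `backward` are available:

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[-1.0f32, 0.5, 2.0], &dev)?;
    // The comparison only produces a constant 0/1 mask; it no longer aborts backward().
    let mask = x.ge(&x.zeros_like()?)?;
    // A ReLU spelled out with cmp + where_cond: gradient flows through the branches,
    // not through the mask itself.
    let y = mask.where_cond(&x, &x.zeros_like()?)?;
    let grads = y.sum_all()?.backward()?;
    println!("{:?}", grads.get(&x)); // expected: [0., 1., 1.]
    Ok(())
}
```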
@@ -268,7 +282,12 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                     }
-                    Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
+                    Op::Unary(arg, UnaryOp::Abs) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let ones = arg.ones_like()?;
+                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
+                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
+                    }
                     Op::Unary(arg, UnaryOp::Exp) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * *node)?)?
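The Abs arm builds the local gradient sign(x) from the comparison `x >= 0` and `where_cond`, picking +1 where the input is non-negative and -1 elsewhere (so the convention at exactly 0 is +1), then multiplies the incoming gradient by that mask. A standalone finite-difference check of the rule (plain Rust, not candle code):

```rust
// d|x|/dx is +1 for x > 0 and -1 for x < 0; the arm above picks +1 at x == 0.
fn abs_grad(x: f32) -> f32 {
    if x >= 0.0 { 1.0 } else { -1.0 }
}

fn main() {
    let eps = 1e-3f32;
    for &x in &[-2.0f32, -0.5, 0.7, 3.0] {
        let numeric = ((x + eps).abs() - (x - eps).abs()) / (2.0 * eps);
        println!("x = {x}: analytic = {}, numeric = {numeric}", abs_grad(x));
    }
}
```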
@@ -303,12 +322,8 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Reduce(_, ReduceOp::ArgMin, _) => {
-                        Err(Error::BackwardNotSupported { op: "argmin" })?
-                    }
-                    Op::Reduce(_, ReduceOp::ArgMax, _) => {
-                        Err(Error::BackwardNotSupported { op: "argmax" })?
-                    }
+                    Op::Reduce(_, ReduceOp::ArgMin, _) => {}
+                    Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                     Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
                     Op::Reshape(arg) => {
                         let arg_grad = grad.reshape(arg.dims())?;
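Like Cmp above, argmin/argmax return integer indices that stay constant under small perturbations of the input, so the correct gradient contribution is zero and an empty arm is enough. The payoff is that graphs which only use the indices for selection can now be differentiated end to end. A hedged sketch, assuming a recent candle version with `Var`, `argmax_keepdim` and `backward`:

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[1.0f32, 5.0, 3.0], &dev)?;
    // The argmax itself carries no gradient, but it no longer errors out in backward().
    let idx = x.argmax_keepdim(0)?;
    // Selecting the max value through gather keeps the graph differentiable;
    // gather's backward is the scatter_add rule visible in the context of the first hunk.
    let max_val = x.gather(&idx, 0)?;
    let grads = max_val.sum_all()?.backward()?;
    println!("{:?}", grads.get(&x)); // expected: one-hot at the argmax, i.e. [0., 1., 0.]
    Ok(())
}
```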
@@ -316,7 +331,11 @@ impl Tensor {
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
-                    Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+                    Op::Unary(arg, UnaryOp::Relu) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
+                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
+                    }
                     Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                     Op::CustomOp1(arg, c) => {
                         if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
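The Relu arm masks the incoming gradient with `x >= 0` cast to the input dtype, i.e. a 0/1 gate (again passing the gradient through at exactly 0). A standalone illustration of the rule (plain Rust, not candle code):

```rust
// Relu backward: grad_in = grad_out * (x >= 0), so negative inputs get zero gradient.
fn main() {
    let x = [-2.0f32, -0.5, 1.5, 3.0];
    let grad_out = [1.0f32, 1.0, 1.0, 1.0];
    let grad_in: Vec<f32> = x
        .iter()
        .zip(&grad_out)
        .map(|(&xi, &g)| if xi >= 0.0 { g } else { 0.0 })
        .collect();
    println!("{grad_in:?}"); // [0.0, 0.0, 1.0, 1.0]
}
```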