Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
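For context on the new op, here is a minimal sketch of what `gather` computes, assuming the current `candle_core` crate name and a CPU device (the tensor values and index layout are illustrative only): along the chosen dimension, each output element is read from the source at the position given by the index tensor, so the output takes the shape of the indexes.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 2x3 source tensor.
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // Index tensor with the same rank as `t`; its values index into dim 1.
    let idx = Tensor::new(&[[0u32, 2], [1, 0]], &dev)?;
    // Gather along dim 1: out[i][j] = t[i][idx[i][j]].
    let picked = t.gather(&idx, 1)?;
    println!("{picked}"); // [[1., 3.], [5., 4.]]
    Ok(())
}
```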
@@ -39,6 +39,7 @@ impl Tensor {
         } else if let Some(op) = node.op() {
             match op {
                 Op::IndexAdd(t1, t2, t3, _)
+                | Op::ScatterAdd(t1, t2, t3, _)
                 | Op::CustomOp3(t1, t2, t3, _)
                 | Op::WhereCond(t1, t2, t3) => {
                     let (tg, nodes) = walk(t1, nodes, already_seen);
@@ -56,6 +57,7 @@ impl Tensor {
                 }
                 | Op::CustomOp2(lhs, rhs, _)
                 | Op::Binary(lhs, rhs, _)
+                | Op::Gather(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
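The two hunks above register the new `ScatterAdd` and `Gather` ops in the dependency walk, so both of their tensor inputs are visited when collecting nodes for backprop. As a rough sketch of the companion op (same hypothetical setup and illustrative values as the sketch above), `scatter_add` goes the other way round: it adds source values into a copy of the receiver at the indexed positions along the chosen dimension.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn scatter_add_demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    let init = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let idx = Tensor::new(&[[0u32, 2], [1, 0]], &dev)?;
    let src = Tensor::new(&[[10f32, 20.], [30., 40.]], &dev)?;
    // Along dim 1: out[i][idx[i][j]] += src[i][j].
    init.scatter_add(&idx, &src, 1) // [[10., 0., 20.], [40., 30., 0.]]
}
```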
@@ -162,6 +164,11 @@ impl Tensor {
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                Op::Gather(arg, indexes, dim) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+                }
+                Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
                 Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                 Op::IndexSelect(arg, indexes, dim) => {
                     let sum_grad = grads.or_insert(arg)?;
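The gradient rule added in this last hunk is the usual one: every element read by `gather` routes its incoming gradient back to the source position it was read from, and positions read several times accumulate, which is exactly what `scatter_add` does. A small end-to-end check, assuming today's `Var`/`backward` API (which may postdate this commit) and made-up values:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let arg = Var::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // Gather column 0 twice in row 0 and column 1 twice in row 1.
    let idx = Tensor::new(&[[0u32, 0], [1, 1]], &dev)?;
    let picked = arg.gather(&idx, 1)?;
    let loss = picked.sum_all()?;
    let grads = loss.backward()?;
    // Each source element receives one unit of gradient per time it was gathered.
    let g = grads.get(&arg).expect("no gradient for arg");
    println!("{g}"); // [[2., 0., 0.], [0., 2., 0.]]
    Ok(())
}
```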