Add the gather op. (#219)

* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
This commit is contained in:
Laurent Mazare
2023-07-22 08:21:28 +02:00
committed by GitHub
parent 6eeea1b04e
commit 52c5d8c087
9 changed files with 315 additions and 14 deletions

View File

@ -39,6 +39,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
@ -56,6 +57,7 @@ impl Tensor {
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
@ -162,6 +164,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;