Add the gather op. (#219)

* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
This commit is contained in:
Laurent Mazare
2023-07-22 08:21:28 +02:00
committed by GitHub
parent 6eeea1b04e
commit 52c5d8c087
9 changed files with 315 additions and 14 deletions

View File

@ -66,6 +66,8 @@ pub(crate) enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),