Add the gather op. (#219)

* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
This commit is contained in:
Laurent Mazare
2023-07-22 08:21:28 +02:00
committed by GitHub
parent 6eeea1b04e
commit 52c5d8c087
9 changed files with 315 additions and 14 deletions

View File

@ -40,6 +40,16 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,