Add the gather op. (#219)

* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
This commit is contained in:
Laurent Mazare
2023-07-22 08:21:28 +02:00
committed by GitHub
parent 6eeea1b04e
commit 52c5d8c087
9 changed files with 315 additions and 14 deletions

View File

@ -325,6 +325,51 @@ impl Storage {
}
}
pub(crate) fn gather(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
match (self, indexes) {
(Self::Cpu(s), Self::Cpu(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn scatter_add(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn index_add(
&self,
l: &Layout,