Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
This commit is contained in:
Laurent Mazare
2025-04-25 21:46:58 +02:00
committed by GitHub
parent 3aeb9575c7
commit 3827685524
15 changed files with 429 additions and 19 deletions

View File

@ -628,6 +628,34 @@ impl Storage {
}
}
pub(crate) fn scatter(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn scatter_add(
&self,
l: &Layout,