mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the gather op. (#219)
* Start adding gather. * Gather cpu implementation + use in simple training. * Add scatter_add for the gradient of gather. * Simple cpu implementation of scatter_add. * Use gather in the simple-training backprop.
This commit is contained in:
@ -325,6 +325,51 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn gather(
|
||||
&self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(indexes, "index-add")?;
|
||||
match (self, indexes) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn index_add(
|
||||
&self,
|
||||
l: &Layout,
|
||||
|
Reference in New Issue
Block a user