Boilerplate code for the sum operator.

This commit is contained in:
laurent
2023-06-25 09:35:17 +01:00
parent 7ccf27dda2
commit 3852a85af3
7 changed files with 61 additions and 1 deletions

View File

@ -72,6 +72,19 @@ impl Storage {
}
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,