Softmax cpu implementation.

This commit is contained in:
laurent
2023-06-23 22:26:53 +01:00
parent 8443963d4f
commit d0a91db8fd
5 changed files with 55 additions and 8 deletions

View File

@ -72,11 +72,12 @@ impl Storage {
}
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) {
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim),
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim),
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
}
Ok(())
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {