Softmax tests + fix.

This commit is contained in:
laurent
2023-06-23 22:46:36 +01:00
parent d0a91db8fd
commit ae5dc5fbc6
3 changed files with 48 additions and 11 deletions

View File

@ -150,23 +150,22 @@ impl CpuStorage {
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way.
let dims = shape.dims();
let number_of_slices = dims[dim];
let elem_per_slice = dims[dim];
let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product();
let elem_count = shape.elem_count();
match self {
Self::F32(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
while idx < elem_count {
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx] as f64;
idx += prod_post_dim
}
let sum = sum as f32;
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
while idx < elem_count {
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
@ -177,13 +176,13 @@ impl CpuStorage {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
while idx < elem_count {
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx];
idx += prod_post_dim
}
let mut idx = pre_idx * prod_post_dim * number_of_slices + post_idx;
while idx < elem_count {
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}