diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index d2c727f3..449c1e8b 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -785,19 +785,33 @@ fn divide_by_sum_over_dim(s: &mut [T], shape: &Shape, dim: usize) let elem_per_slice = dims[dim]; let prod_pre_dim = dims[..dim].iter().product(); let prod_post_dim = dims[dim + 1..].iter().product(); - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { + if prod_post_dim == 1 { + for pre_idx in 0..prod_pre_dim { let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += s[idx].to_f64(); - idx += prod_post_dim + let idx = pre_idx * elem_per_slice; + for v in s[idx..idx + elem_per_slice].iter() { + sum += v.to_f64(); } let sum = T::from_f64(sum); - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - s[idx] /= sum; - idx += prod_post_dim + for v in s[idx..idx + elem_per_slice].iter_mut() { + *v /= sum + } + } + } else { + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += s[idx].to_f64(); + idx += prod_post_dim + } + let sum = T::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + s[idx] /= sum; + idx += prod_post_dim + } } } }