Add an optimized case when performing the softmax over the last dimension. (#195)

This commit is contained in:
Laurent Mazare
2023-07-18 18:59:50 +02:00
committed by GitHub
parent 6623c227d8
commit fd55fc9592

View File

@ -785,19 +785,33 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
let elem_per_slice = dims[dim]; let elem_per_slice = dims[dim];
let prod_pre_dim = dims[..dim].iter().product(); let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product(); let prod_post_dim = dims[dim + 1..].iter().product();
for pre_idx in 0..prod_pre_dim { if prod_post_dim == 1 {
for post_idx in 0..prod_post_dim { for pre_idx in 0..prod_pre_dim {
let mut sum = 0f64; let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; let idx = pre_idx * elem_per_slice;
for _ in 0..elem_per_slice { for v in s[idx..idx + elem_per_slice].iter() {
sum += s[idx].to_f64(); sum += v.to_f64();
idx += prod_post_dim
} }
let sum = T::from_f64(sum); let sum = T::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; for v in s[idx..idx + elem_per_slice].iter_mut() {
for _ in 0..elem_per_slice { *v /= sum
s[idx] /= sum; }
idx += prod_post_dim }
} else {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += s[idx].to_f64();
idx += prod_post_dim
}
let sum = T::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
s[idx] /= sum;
idx += prod_post_dim
}
} }
} }
} }