mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add an optimized case when performing the softmax over the last dimension. (#195)
This commit is contained in:
@ -785,19 +785,33 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
|
|||||||
let elem_per_slice = dims[dim];
|
let elem_per_slice = dims[dim];
|
||||||
let prod_pre_dim = dims[..dim].iter().product();
|
let prod_pre_dim = dims[..dim].iter().product();
|
||||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||||
for pre_idx in 0..prod_pre_dim {
|
if prod_post_dim == 1 {
|
||||||
for post_idx in 0..prod_post_dim {
|
for pre_idx in 0..prod_pre_dim {
|
||||||
let mut sum = 0f64;
|
let mut sum = 0f64;
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
let idx = pre_idx * elem_per_slice;
|
||||||
for _ in 0..elem_per_slice {
|
for v in s[idx..idx + elem_per_slice].iter() {
|
||||||
sum += s[idx].to_f64();
|
sum += v.to_f64();
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
}
|
||||||
let sum = T::from_f64(sum);
|
let sum = T::from_f64(sum);
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
for v in s[idx..idx + elem_per_slice].iter_mut() {
|
||||||
for _ in 0..elem_per_slice {
|
*v /= sum
|
||||||
s[idx] /= sum;
|
}
|
||||||
idx += prod_post_dim
|
}
|
||||||
|
} else {
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
sum += s[idx].to_f64();
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let sum = T::from_f64(sum);
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
s[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user