Tweaks to softmax. (#745)

This commit is contained in:
Laurent Mazare
2023-09-05 16:22:27 +02:00
committed by GitHub
parent 1c9e5394a5
commit 6615daf242
2 changed files with 84 additions and 19 deletions

View File

@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut max = T::neg_infinity();
for &s in src.iter() {
max = T::max(s, max)
}
let mut sum_exp = T::zero();
unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*s - max).exp();
sum_exp += *d
}
let mut sum_exp = T::zero();
unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
for d in dst.iter_mut() {
*d /= sum_exp
}