mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Tweaks to softmax. (#745)
This commit is contained in:
@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
.zip(dst.par_chunks_mut(dim_m1))
|
||||
.for_each(|(src, dst)| {
|
||||
let mut max = T::neg_infinity();
|
||||
for &s in src.iter() {
|
||||
max = T::max(s, max)
|
||||
}
|
||||
let mut sum_exp = T::zero();
|
||||
unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
|
||||
for (s, d) in src.iter().zip(dst.iter_mut()) {
|
||||
*d = (*s - max).exp();
|
||||
sum_exp += *d
|
||||
}
|
||||
let mut sum_exp = T::zero();
|
||||
unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
|
||||
for d in dst.iter_mut() {
|
||||
*d /= sum_exp
|
||||
}
|
||||
|
Reference in New Issue
Block a user