mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
This commit is contained in:
@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
/// Divides, in place, every element of `s` by the sum of the elements that share its
/// position along every dimension except `dim`.
///
/// `s` is treated as contiguous data starting at offset 0 and laid out according to
/// `shape`. Each sum is accumulated in `f64`, converted back to `T`, and then used as the
/// divisor for the elements that contributed to it.
///
/// Always returns `Ok(())`. No validation is performed here: out-of-bounds indexing panics
/// rather than producing an `Err`.
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
    let dims = shape.dims();
    // Number of elements along the reduced dimension.
    let dim_len: usize = dims[dim];
    // Number of index combinations before / after the reduced dimension.
    let outer: usize = dims[..dim].iter().product();
    let stride: usize = dims[dim + 1..].iter().product();

    if stride == 1 {
        // The reduced dimension is the innermost one: each group is a contiguous run.
        for row in 0..outer {
            let start = row * dim_len;
            let total = s[start..start + dim_len]
                .iter()
                .fold(0f64, |acc, v| acc + v.to_f64());
            let total = T::from_f64(total);
            for v in s[start..start + dim_len].iter_mut() {
                *v /= total
            }
        }
    } else {
        // General case: consecutive elements of a group are `stride` positions apart.
        for row in 0..outer {
            for col in 0..stride {
                let first = row * stride * dim_len + col;
                let total =
                    (0..dim_len).fold(0f64, |acc, k| acc + s[first + k * stride].to_f64());
                let total = T::from_f64(total);
                for k in 0..dim_len {
                    s[first + k * stride] /= total;
                }
            }
        }
    }
    Ok(())
}
|
||||
|
||||
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
|
||||
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
match self {
|
||||
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::U8(_) | Self::U32(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
Reference in New Issue
Block a user