Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions
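The stability problem being addressed is the usual one: the removed softmax below exponentiates the raw logits, and exp(x) overflows f16/bf16 (and eventually f32) once the logits get moderately large. The standard remedy behind the commit title is to subtract the per-slice maximum before exponentiating, which leaves the result unchanged because the common factor exp(-m) cancels:

    softmax(x)_i = exp(x_i) / sum_j exp(x_j) = exp(x_i - m) / sum_j exp(x_j - m),   with m = max_j x_j

After the shift every exponent is <= 0, so each exp stays in [0, 1] and cannot overflow.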

View File

@@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;

View File

@@ -90,7 +90,6 @@ impl Tensor {
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::CustomOp1(node, _) => {
@@ -324,7 +323,6 @@ impl Tensor {
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
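With Op::Softmax removed from the op enum (see the hunk further down), the backward pass no longer needs the BackwardNotSupported branch above: a softmax built from exp, sum and broadcast division differentiates through those primitives instead of requiring a fused rule. For reference, the vector-Jacobian product such a fused rule would have had to compute is

    dL/dx_j = s_j * (dL/ds_j - sum_i dL/ds_i * s_i),   where s = softmax(x).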

View File

@@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
}
}
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
let dims = shape.dims();
let elem_per_slice = dims[dim];
let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product();
if prod_post_dim == 1 {
for pre_idx in 0..prod_pre_dim {
let mut sum = 0f64;
let idx = pre_idx * elem_per_slice;
for v in s[idx..idx + elem_per_slice].iter() {
sum += v.to_f64();
}
let sum = T::from_f64(sum);
for v in s[idx..idx + elem_per_slice].iter_mut() {
*v /= sum
}
}
} else {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += s[idx].to_f64();
idx += prod_post_dim
}
let sum = T::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
s[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Ok(())
}
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
if v.is_sign_positive() {
v
@@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
Cmp(op).map(self, lhs_l, rhs, rhs_l)
}
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
match self {
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
Self::U8(_) | Self::U32(_) => Ok(()),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
Affine(mul, add).map(self, layout)
}
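For context on the helper deleted above: it treats the contiguous buffer as having shape [prod_pre_dim, elem_per_slice, prod_post_dim], so consecutive elements of a slice along `dim` sit `prod_post_dim` apart, which is why the inner loops step the index by that stride. A minimal, f32-only sketch of the same normalization (a hypothetical standalone helper, not part of this diff):

    fn normalize_over_dim(s: &mut [f32], dims: &[usize], dim: usize) {
        // View the buffer as [prod_pre_dim, elem_per_slice, prod_post_dim].
        let elem_per_slice = dims[dim];
        let prod_pre_dim: usize = dims[..dim].iter().product();
        let prod_post_dim: usize = dims[dim + 1..].iter().product();
        for pre in 0..prod_pre_dim {
            for post in 0..prod_post_dim {
                // First element of this slice; its neighbours along `dim`
                // are spaced `prod_post_dim` apart.
                let base = pre * elem_per_slice * prod_post_dim + post;
                let sum: f32 = (0..elem_per_slice)
                    .map(|i| s[base + i * prod_post_dim])
                    .sum();
                for i in 0..elem_per_slice {
                    s[base + i * prod_post_dim] /= sum;
                }
            }
        }
    }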

View File

@@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;

View File

@@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@@ -93,7 +93,6 @@ pub enum Op {
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
Reshape(Tensor),
Softmax(Tensor, usize),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),

View File

@@ -125,15 +125,6 @@ impl Storage {
}
}
// This assumes a contiguous layout and no offset.
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
}
Ok(())
}
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@@ -553,40 +553,6 @@ impl Tensor {
}
}
/// Applies the softmax function to the input tensor, rescaling the elements so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = a.softmax(1)?;
/// assert_eq!(
/// a.to_vec2::<f32>()?,
/// &[
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
/// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
/// ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "softmax")?;
// TODO: unify the two branches.
if self.device().is_cuda() {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
// the operation.
let exp = self.exp()?;
let sum_exp = exp.sum_keepdim(dim)?;
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
Ok(from_storage(storage, shape.clone(), op, false))
}
}
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
match dims {
[] => Ok(self),
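The hunks above only show the removal of the old Tensor::softmax and its divide_by_sum_over_dim backend hook; the replacement itself is not visible in this excerpt. Expressed with the remaining tensor ops, a numerically stable softmax looks roughly like the sketch below (assuming max_keepdim and broadcast_sub are available, in the same way sum_keepdim and broadcast_div are used in the code removed above):

    use candle::{Result, Tensor};

    fn stable_softmax(xs: &Tensor, dim: usize) -> Result<Tensor> {
        // Subtracting the per-slice maximum keeps every exponent <= 0, so the
        // exp cannot overflow even in f16/bf16; the shift cancels in the ratio.
        let max = xs.max_keepdim(dim)?;
        let shifted = xs.broadcast_sub(&max)?;
        let num = shifted.exp()?;
        let den = num.sum_keepdim(dim)?;
        num.broadcast_div(&den)
    }

Because this version is built purely from existing primitive ops, it also gets gradients for free through the regular backward pass, with no dedicated Op::Softmax variant needed.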