mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
|
||||
|
@ -90,7 +90,6 @@ impl Tensor {
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
@ -324,7 +323,6 @@ impl Tensor {
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
let dims = shape.dims();
|
||||
let elem_per_slice = dims[dim];
|
||||
let prod_pre_dim = dims[..dim].iter().product();
|
||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||
if prod_post_dim == 1 {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
let mut sum = 0f64;
|
||||
let idx = pre_idx * elem_per_slice;
|
||||
for v in s[idx..idx + elem_per_slice].iter() {
|
||||
sum += v.to_f64();
|
||||
}
|
||||
let sum = T::from_f64(sum);
|
||||
for v in s[idx..idx + elem_per_slice].iter_mut() {
|
||||
*v /= sum
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += s[idx].to_f64();
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = T::from_f64(sum);
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
s[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
|
||||
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
match self {
|
||||
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::U8(_) | Self::U32(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
|
||||
}
|
||||
|
||||
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = U::V.map(&self.slice, &device, layout)?;
|
||||
|
@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -93,7 +93,6 @@ pub enum Op {
|
||||
Broadcast(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
Reshape(Tensor),
|
||||
Softmax(Tensor, usize),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Elu(Tensor, f64),
|
||||
|
@ -125,15 +125,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
// This assumes a contiguous layout and no offset.
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
|
@ -553,40 +553,6 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
|
||||
/// let a = a.softmax(1)?;
|
||||
/// assert_eq!(
|
||||
/// a.to_vec2::<f32>()?,
|
||||
/// &[
|
||||
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
|
||||
/// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
|
||||
/// ]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "softmax")?;
|
||||
// TODO: unify the two branches.
|
||||
if self.device().is_cuda() {
|
||||
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||
// the operation.
|
||||
let exp = self.exp()?;
|
||||
let sum_exp = exp.sum_keepdim(dim)?;
|
||||
exp.broadcast_div(&sum_exp)
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
// The resulting storage is contiguous.
|
||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
}
|
||||
|
||||
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
|
||||
match dims {
|
||||
[] => Ok(self),
|
||||
|
Reference in New Issue
Block a user