Add a custom softmax implementation. (#744)

* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
Laurent Mazare
2023-09-05 15:20:23 +02:00
committed by GitHub
parent a8410bf35e
commit 1c9e5394a5
5 changed files with 109 additions and 18 deletions
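For reference, the custom CPU kernel in the diff below computes, for each contiguous slice along the last dimension, the numerically stable softmax (the slice maximum is subtracted before exponentiating):

\[ \operatorname{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}} \]

It does this in a single pass per slice and parallelizes over slices with rayon's `par_chunks`, instead of composing several tensor ops as the generic `softmax` does.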


@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
         }
     }
 }
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+}
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
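
A minimal usage sketch follows; it is not the test added by this commit, the tensor values and `main` scaffolding are illustrative only, and it assumes the crates are available under the same `candle`/`candle_nn` names used in this file. It compares the custom op against the generic softmax over the last dimension:

use candle::{D, Device, Result, Tensor};
use candle_nn::ops;

fn main() -> Result<()> {
    // Illustrative CPU input; the last dimension holds the values to normalize.
    let xs = Tensor::new(&[[1f32, 2., 3.], [-1., 0., 4.]], &Device::Cpu)?;
    // Fast path: the custom CPU op defined above.
    let fast = ops::softmax_last_dim(&xs)?;
    // Reference path: the generic tensor-op softmax over the last dimension.
    let slow = ops::softmax(&xs, D::Minus1)?;
    // Each row of either result is non-negative and sums to 1.
    println!("{fast}");
    println!("{slow}");
    Ok(())
}

On the CPU the call goes through `SoftmaxLastDim::cpu_fwd`; on any other device `softmax_last_dim` falls back to the generic `softmax`, which is what the "use the slow implementation on cuda" bullet refers to until a CUDA kernel is added.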