mirror of https://github.com/huggingface/candle.git
Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.
* Add softmaxlastdim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
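For context, here is a minimal sketch of how the `softmax_last_dim` entry point added in the diff below would be called, and how its output can be compared against the existing generic `softmax` over the last dimension. This is an illustrative example, not part of the commit; it assumes the workspace crate names `candle` and `candle_nn` used in the diff.

use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;

    // New fused path: the custom CPU op applied over the last dimension.
    let fused = candle_nn::ops::softmax_last_dim(&xs)?;
    // Existing path: softmax composed from regular tensor ops.
    let composed = candle_nn::ops::softmax(&xs, D::Minus1)?;

    // Each row should sum to 1 and the two results should agree closely.
    println!("fused:    {:?}", fused.to_vec2::<f32>()?);
    println!("composed: {:?}", composed.to_vec2::<f32>()?);
    Ok(())
}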
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
         }
     }
 }
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
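The commit message mentions adding softmax-last-dim to the benchmarks; the actual benchmark harness is not part of this diff. As a rough standalone way to compare the fused CPU op against the composed softmax, one could time repeated calls directly. This is again an illustrative sketch under the same crate-name assumptions, not the crate's benchmark code.

use candle::{Device, Result, Tensor, D};
use std::time::Instant;

fn main() -> Result<()> {
    // A batch of 128 rows of 1024 elements, softmaxed over the last dimension.
    let xs = Tensor::rand(0f32, 1f32, (128, 1024), &Device::Cpu)?;

    let start = Instant::now();
    for _ in 0..1000 {
        let _ = candle_nn::ops::softmax_last_dim(&xs)?;
    }
    println!("custom op:   {:?}", start.elapsed());

    let start = Instant::now();
    for _ in 0..1000 {
        let _ = candle_nn::ops::softmax(&xs, D::Minus1)?;
    }
    println!("composed op: {:?}", start.elapsed());
    Ok(())
}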