Add the elu op. (#113)

This commit is contained in:
Laurent Mazare
2023-07-09 21:56:31 +01:00
committed by GitHub
parent ea5dfa69bc
commit 270997a055
8 changed files with 95 additions and 0 deletions

View File

@ -463,6 +463,14 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
Ok(())
}
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
if v.is_sign_positive() {
v
} else {
(v.exp() - T::one()) * alpha
}
}
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
@ -666,6 +674,30 @@ impl CpuStorage {
Affine(mul, add).map(self, layout)
}
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
Self::BF16(storage) => {
let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
Ok(Self::F16(data))
}
Self::F32(storage) => {
let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
}
}
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
match self {
Self::BF16(storage) => {