mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add the elu op. (#113)
This commit is contained in:
@ -463,6 +463,14 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
} else {
|
||||
(v.exp() - T::one()) * alpha
|
||||
}
|
||||
}
|
||||
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
@ -666,6 +674,30 @@ impl CpuStorage {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
||||
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, layout, |v| elu(v, alpha));
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
|
||||
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
|
Reference in New Issue
Block a user