mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Use a macro to handle the dtype pattern matching. (#215)
This commit is contained in:
@ -1,15 +1,15 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cpu_backend;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
|
||||
mod test_utils;
|
||||
use test_utils::to_vec1_round;
|
||||
|
||||
fn fwd<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
} else {
|
||||
let alpha = T::from(alpha).unwrap_or(T::nan());
|
||||
(v.exp() - T::one()) * alpha
|
||||
}
|
||||
}
|
||||
@ -24,33 +24,12 @@ impl CustomOp1 for Elu {
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use CpuStorage::*;
|
||||
|
||||
// In this example, we pattern match over the different dtypes. Some helper functions and
|
||||
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
|
||||
// e.g. `Map1`.
|
||||
let storage = match s {
|
||||
BF16(s) => {
|
||||
let alpha = bf16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
BF16(data)
|
||||
}
|
||||
F16(s) => {
|
||||
let alpha = f16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
F16(data)
|
||||
}
|
||||
F32(s) => {
|
||||
let alpha = self.alpha as f32;
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
F32(data)
|
||||
}
|
||||
F64(s) => {
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha));
|
||||
F64(data)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
|
||||
};
|
||||
let storage = candle::map_dtype!(
|
||||
"elu",
|
||||
s,
|
||||
|s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)),
|
||||
(BF16, F16, F32, F64)
|
||||
);
|
||||
Ok((storage, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
@ -69,10 +48,11 @@ fn custom_op1_no_backward() -> Result<()> {
|
||||
}
|
||||
|
||||
// Define a similar struct as Elu but with backward support.
|
||||
fn bwd<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
fn bwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
|
||||
if v.is_sign_positive() {
|
||||
T::one()
|
||||
} else {
|
||||
let alpha = T::from(alpha).unwrap_or(T::nan());
|
||||
v.exp() * alpha
|
||||
}
|
||||
}
|
||||
@ -87,33 +67,12 @@ impl CustomOp1 for EluBackward {
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use CpuStorage::*;
|
||||
|
||||
// In this example, we pattern match over the different dtypes. Some helper functions and
|
||||
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
|
||||
// e.g. `Map1`.
|
||||
let storage = match s {
|
||||
BF16(s) => {
|
||||
let alpha = bf16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
BF16(data)
|
||||
}
|
||||
F16(s) => {
|
||||
let alpha = f16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
F16(data)
|
||||
}
|
||||
F32(s) => {
|
||||
let alpha = self.alpha as f32;
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
F32(data)
|
||||
}
|
||||
F64(s) => {
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha));
|
||||
F64(data)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
|
||||
};
|
||||
let storage = candle::map_dtype!(
|
||||
"elu-bwd",
|
||||
s,
|
||||
|s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)),
|
||||
(BF16, F16, F32, F64)
|
||||
);
|
||||
Ok((storage, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user