Use a macro to handle the dtype pattern matching. (#215)

This commit is contained in:
Laurent Mazare
2023-07-21 17:03:51 +02:00
committed by GitHub
parent a6bcdfb269
commit 4a100875bf
2 changed files with 26 additions and 57 deletions

View File

@ -1660,3 +1660,13 @@ impl BackendDevice for CpuDevice {
Ok(storage)
}
}
#[macro_export]
macro_rules! map_dtype {
($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
match $storage {
$(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
}
};
}

View File

@ -1,15 +1,15 @@
use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
mod test_utils;
use test_utils::to_vec1_round;
fn fwd<T: num_traits::Float>(v: T, alpha: T) -> T {
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
v
} else {
let alpha = T::from(alpha).unwrap_or(T::nan());
(v.exp() - T::one()) * alpha
}
}
@ -24,33 +24,12 @@ impl CustomOp1 for Elu {
}
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
use CpuStorage::*;
// In this example, we pattern match over the different dtypes. Some helper functions and
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
// e.g. `Map1`.
let storage = match s {
BF16(s) => {
let alpha = bf16::from_f64(self.alpha);
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
BF16(data)
}
F16(s) => {
let alpha = f16::from_f64(self.alpha);
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
F16(data)
}
F32(s) => {
let alpha = self.alpha as f32;
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
F32(data)
}
F64(s) => {
let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha));
F64(data)
}
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
};
let storage = candle::map_dtype!(
"elu",
s,
|s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)),
(BF16, F16, F32, F64)
);
Ok((storage, l.shape().clone()))
}
}
@ -69,10 +48,11 @@ fn custom_op1_no_backward() -> Result<()> {
}
// Define a similar struct as Elu but with backward support.
fn bwd<T: num_traits::Float>(v: T, alpha: T) -> T {
fn bwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
T::one()
} else {
let alpha = T::from(alpha).unwrap_or(T::nan());
v.exp() * alpha
}
}
@ -87,33 +67,12 @@ impl CustomOp1 for EluBackward {
}
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
use CpuStorage::*;
// In this example, we pattern match over the different dtypes. Some helper functions and
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
// e.g. `Map1`.
let storage = match s {
BF16(s) => {
let alpha = bf16::from_f64(self.alpha);
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
BF16(data)
}
F16(s) => {
let alpha = f16::from_f64(self.alpha);
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
F16(data)
}
F32(s) => {
let alpha = self.alpha as f32;
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
F32(data)
}
F64(s) => {
let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha));
F64(data)
}
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
};
let storage = candle::map_dtype!(
"elu-bwd",
s,
|s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)),
(BF16, F16, F32, F64)
);
Ok((storage, l.shape().clone()))
}
}