From 4a100875bff616843ce3f83a05113d03dcfccf9c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 21 Jul 2023 17:03:51 +0200 Subject: [PATCH] Use a macro to handle the dtype pattern matching. (#215) --- candle-core/src/cpu_backend.rs | 10 ++++ candle-core/tests/custom_op_tests.rs | 73 ++++++---------------------- 2 files changed, 26 insertions(+), 57 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index d529b173..a471e308 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1660,3 +1660,13 @@ impl BackendDevice for CpuDevice { Ok(storage) } } + +#[macro_export] +macro_rules! map_dtype { + ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => { + match $storage { + $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)* + s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?, + } + }; +} diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index 3e1e0c19..3ce125bc 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -1,15 +1,15 @@ use candle::backend::BackendStorage; use candle::cpu_backend; use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor}; -use half::{bf16, f16}; mod test_utils; use test_utils::to_vec1_round; -fn fwd(v: T, alpha: T) -> T { +fn fwd(v: T, alpha: f64) -> T { if v.is_sign_positive() { v } else { + let alpha = T::from(alpha).unwrap_or(T::nan()); (v.exp() - T::one()) * alpha } } @@ -24,33 +24,12 @@ impl CustomOp1 for Elu { } fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { - use CpuStorage::*; - - // In this example, we pattern match over the different dtypes. Some helper functions and - // traits from the `cpu_backend` module can be used to avoid this in some common cases, see - // e.g. `Map1`. - let storage = match s { - BF16(s) => { - let alpha = bf16::from_f64(self.alpha); - let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); - BF16(data) - } - F16(s) => { - let alpha = f16::from_f64(self.alpha); - let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); - F16(data) - } - F32(s) => { - let alpha = self.alpha as f32; - let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); - F32(data) - } - F64(s) => { - let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)); - F64(data) - } - _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, - }; + let storage = candle::map_dtype!( + "elu", + s, + |s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)), + (BF16, F16, F32, F64) + ); Ok((storage, l.shape().clone())) } } @@ -69,10 +48,11 @@ fn custom_op1_no_backward() -> Result<()> { } // Define a similar struct as Elu but with backward support. -fn bwd(v: T, alpha: T) -> T { +fn bwd(v: T, alpha: f64) -> T { if v.is_sign_positive() { T::one() } else { + let alpha = T::from(alpha).unwrap_or(T::nan()); v.exp() * alpha } } @@ -87,33 +67,12 @@ impl CustomOp1 for EluBackward { } fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { - use CpuStorage::*; - - // In this example, we pattern match over the different dtypes. Some helper functions and - // traits from the `cpu_backend` module can be used to avoid this in some common cases, see - // e.g. `Map1`. - let storage = match s { - BF16(s) => { - let alpha = bf16::from_f64(self.alpha); - let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); - BF16(data) - } - F16(s) => { - let alpha = f16::from_f64(self.alpha); - let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); - F16(data) - } - F32(s) => { - let alpha = self.alpha as f32; - let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); - F32(data) - } - F64(s) => { - let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)); - F64(data) - } - _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, - }; + let storage = candle::map_dtype!( + "elu-bwd", + s, + |s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)), + (BF16, F16, F32, F64) + ); Ok((storage, l.shape().clone())) } }