From 28e1c0730401ff4d2ef180782a9ba9fb5dadcf7e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 17 Jul 2023 10:22:33 +0100 Subject: [PATCH] Process unary functions per block (#180) * Process unary functions per block. * Add some inline hints. --- candle-core/src/cpu_backend.rs | 32 +++++++++++++++++++++++++++++--- candle-core/src/op.rs | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 97e46e74..a2944166 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -98,6 +98,7 @@ struct Sum<'a> { } impl<'a> Map1 for Sum<'a> { + #[inline(always)] fn f(&self, src: &[T], src_layout: &Layout) -> Result> { let mut dst = vec![T::zero(); self.dst_shape.elem_count()]; for (unstr_index, src_index) in src_layout.strided_index().enumerate() { @@ -115,10 +116,35 @@ impl<'a> Map1 for Sum<'a> { } fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { - match layout.contiguous_offsets() { - Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(), - None => layout.strided_index().map(|i| f(vs[i])).collect(), + let mut result = vec![]; + result.reserve(layout.shape().elem_count()); + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + for &v in vs[start_offset..start_offset + len].iter() { + result.push(f(v)) + } + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + } } + result } // This function maps over two strided index sequences. diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 1b2d800d..79473d2a 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -95,21 +95,27 @@ macro_rules! bin_op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("b", $name); const V: Self = $op; + #[inline(always)] fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) } + #[inline(always)] fn f16(v1: f16, v2: f16) -> f16 { $e(v1, v2) } + #[inline(always)] fn f32(v1: f32, v2: f32) -> f32 { $e(v1, v2) } + #[inline(always)] fn f64(v1: f64, v2: f64) -> f64 { $e(v1, v2) } + #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } + #[inline(always)] fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) } @@ -128,21 +134,27 @@ macro_rules! unary_op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("u", $name); const V: Self = $op; + #[inline(always)] fn bf16($a: bf16) -> bf16 { $e } + #[inline(always)] fn f16($a: f16) -> f16 { $e } + #[inline(always)] fn f32($a: f32) -> f32 { $e } + #[inline(always)] fn f64($a: f64) -> f64 { $e } + #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } + #[inline(always)] fn u32(_: u32) -> u32 { todo!("no unary function for u32") } @@ -164,6 +176,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt()); impl UnaryOp for Gelu { const NAME: &'static str = "gelu"; const V: Self = Gelu; + #[inline(always)] fn bf16(v: bf16) -> bf16 { bf16::from_f32_const(0.5) * v @@ -174,6 +187,7 @@ impl UnaryOp for Gelu { * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), )) } + #[inline(always)] fn f16(v: f16) -> f16 { f16::from_f32_const(0.5) * v @@ -184,19 +198,23 @@ impl UnaryOp for Gelu { * (f16::ONE + f16::from_f32_const(0.044715) * v * v), )) } + #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } + #[inline(always)] fn f64(v: f64) -> f64 { 0.5 * v * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } + #[inline(always)] fn u8(_: u8) -> u8 { 0 } + #[inline(always)] fn u32(_: u32) -> u32 { 0 } @@ -207,21 +225,27 @@ impl UnaryOp for Relu { const NAME: &'static str = "relu"; const KERNEL: &'static str = "urelu"; const V: Self = Relu; + #[inline(always)] fn bf16(v: bf16) -> bf16 { v.max(bf16::ZERO) } + #[inline(always)] fn f16(v: f16) -> f16 { v.max(f16::ZERO) } + #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } + #[inline(always)] fn f64(v: f64) -> f64 { v.max(0f64) } + #[inline(always)] fn u8(v: u8) -> u8 { v } + #[inline(always)] fn u32(v: u32) -> u32 { v }