From d379a76a9ec6584fe010fb64246e734e46a0d899 Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Sun, 13 Aug 2023 21:09:18 +0200
Subject: [PATCH] Add a softmax bench. (#433)

* Add a softmax bench.

* Add the vectorized sum reduce.
---
 candle-core/examples/cpu_benchmarks.rs | 30 +++++++++++++++++++++++++-
 candle-core/src/cpu_backend.rs         | 20 ++++++++++-------
 candle-core/src/cpu_kernels.rs         | 20 +++++++++++++++++
 3 files changed, 61 insertions(+), 9 deletions(-)

diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index 421a4d50..6c40269f 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -5,9 +5,18 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle_core::{Device, Result, Tensor};
+use candle_core::{Device, Result, Tensor, D};
 use clap::{Parser, Subcommand};
 
+fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
+    let dim = dim.to_index(xs.shape(), "softmax")?;
+    let max = xs.max_keepdim(dim)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let num = diff.exp()?;
+    let den = num.sum_keepdim(dim)?;
+    num.broadcast_div(&den)
+}
+
 trait Benchmark {
     type PreProcessData;
     type RunResult;
@@ -72,6 +81,23 @@ impl Benchmark for Matmul {
     const ITERS: usize = 100;
 }
 
+struct Softmax;
+impl Benchmark for Softmax {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+    fn preprocess() -> Result<Self::PreProcessData> {
+        // Typical whisper tiny size.
+        let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
+        Ok(x)
+    }
+
+    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+        softmax(d, D::Minus1)
+    }
+
+    const ITERS: usize = 100;
+}
+
 fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
     use std::hint::black_box;
 
@@ -90,6 +116,7 @@ enum Task {
     Conv1d,
     Conv2d,
     Matmul,
+    Softmax,
 }
 
 #[derive(Parser, Debug)]
@@ -109,6 +136,7 @@ fn main() -> Result<()> {
         Task::Conv1d => run::<Conv1d>(args.iters)?,
         Task::Conv2d => run::<Conv2d>(args.iters)?,
         Task::Matmul => run::<Matmul>(args.iters)?,
+        Task::Softmax => run::<Softmax>(args.iters)?,
     }
     Ok(())
 }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7eadd170..e6cc61a8 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -278,17 +278,17 @@ impl Map1Any for ReduceIndex {
     }
 }
 
-struct Reduce<'a> {
+struct ReduceSum<'a> {
     dst_shape: &'a Shape,
     reduce_dims: &'a [usize],
     reduce_dims_and_stride: Vec<(usize, usize)>,
 }
 
-impl<'a> Reduce<'a> {
+impl<'a> ReduceSum<'a> {
     #[inline(always)]
     fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
     where
-        T: Clone + Copy,
+        T: WithDType,
         F: Fn(T, T) -> T,
     {
         let mut dst = vec![start_elt; self.dst_shape.elem_count()];
@@ -312,9 +312,13 @@ impl<'a> Reduce<'a> {
                 .product::<usize>();
             for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                 let src_i = dst_i * reduce_sz;
-                for &s in src[src_i..src_i + reduce_sz].iter() {
-                    *dst_v = f(*dst_v, s)
-                }
+                unsafe {
+                    T::vec_reduce_sum(
+                        src[src_i..src_i + reduce_sz].as_ptr(),
+                        dst_v,
+                        reduce_sz,
+                    )
+                };
             }
             return Ok(dst);
         };
@@ -346,7 +350,7 @@ impl<'a> Reduce<'a> {
     }
 }
 
-impl<'a> Map1 for Reduce<'a> {
+impl<'a> Map1 for ReduceSum<'a> {
     #[inline(always)]
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
@@ -1697,7 +1701,7 @@ impl BackendStorage for CpuStorage {
             .iter()
             .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
             .collect();
-        Reduce {
+        ReduceSum {
             dst_shape: &dst_shape,
             reduce_dims: &reduce_dims,
             reduce_dims_and_stride,
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
index 75509ba9..edcbe740 100644
--- a/candle-core/src/cpu_kernels.rs
+++ b/candle-core/src/cpu_kernels.rs
@@ -12,6 +12,20 @@ pub trait VecDot: num_traits::NumAssign + Copy {
             *res += *lhs.add(i) * *rhs.add(i)
         }
     }
+
+    /// Sum of all elements in a vector.
+    ///
+    /// # Safety
+    ///
+    /// The length of `xs` must be at least `len`. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
+        *res = Self::zero();
+        for i in 0..len {
+            *res += *xs.add(i)
+        }
+    }
 }
 
 impl VecDot for f32 {
@@ -19,6 +33,12 @@ impl VecDot for f32 {
     unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
         ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
     }
+
+    // TODO: enable the following once the updated ggblas is available.
+    // #[inline(always)]
+    // unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
+    //     ggblas::ggml::vec_reduce_sum(xs, res, len)
+    // }
 }
 
 impl VecDot for f64 {}