From 1c9e5394a5056aadc948f9330ea31fea4972e65e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 5 Sep 2023 15:20:23 +0200
Subject: [PATCH] Add a custom softmax implementation. (#744)

* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
---
 .../examples/stable-diffusion/attention.rs |  2 +-
 candle-nn/Cargo.toml                       |  4 ++
 .../examples/cpu_benchmarks.rs             | 42 ++++++-----
 candle-nn/src/ops.rs                       | 69 ++++++++++++++++++-
 candle-nn/tests/ops.rs                     | 10 +++
 5 files changed, 109 insertions(+), 18 deletions(-)
 rename {candle-core => candle-nn}/examples/cpu_benchmarks.rs (80%)

diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index 1ae1bfc3..000cd2fe 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -198,7 +198,7 @@ impl CrossAttention {
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
             let xs = {
                 let _enter = self.span_softmax.enter();
-                nn::ops::softmax(&xs, D::Minus1)?
+                nn::ops::softmax_last_dim(&xs)?
             };
             xs.matmul(&value)?.to_dtype(in_dtype)?
         };
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index aa055583..db0f6a8f 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -12,12 +12,16 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+half = { workspace = true }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
 safetensors = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
+clap = { workspace = true }
 
 [features]
 default = []
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
similarity index 80%
rename from candle-core/examples/cpu_benchmarks.rs
rename to candle-nn/examples/cpu_benchmarks.rs
index 13175ac1..20c92dbb 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -5,19 +5,10 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle_core::quantized::GgmlType;
-use candle_core::{Device, Result, Tensor, D};
+use candle::quantized::GgmlType;
+use candle::{Device, Result, Tensor, D};
 use clap::{Parser, Subcommand};
 
-fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
-    let dim = dim.to_index(xs.shape(), "softmax")?;
-    let max = xs.max_keepdim(dim)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(dim)?;
-    num.broadcast_div(&den)
-}
-
 trait Benchmark {
     type PreProcessData;
     type RunResult;
@@ -86,12 +77,12 @@ impl Benchmark for Matmul {
 // https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
 struct QMatMul;
 impl Benchmark for QMatMul {
-    type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
+    type PreProcessData = (candle::quantized::QMatMul, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
-        let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
-        let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
-        let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
+        let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
+        let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+        let mm = candle::quantized::QMatMul::from_qtensor(mm);
         let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
         Ok((mm, arg))
     }
@@ -114,7 +105,24 @@ impl Benchmark for Softmax {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        softmax(d, D::Minus1)
+        candle_nn::ops::softmax(d, D::Minus1)
+    }
+
+    const ITERS: usize = 100;
+}
+
+struct SoftmaxLastDim;
+impl Benchmark for SoftmaxLastDim {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+    fn preprocess() -> Result<Self::PreProcessData> {
+        // Typical whisper tiny size.
+        let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
+        Ok(x)
+    }
+
+    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+        candle_nn::ops::softmax_last_dim(d)
     }
 
     const ITERS: usize = 100;
@@ -140,6 +148,7 @@ enum Task {
     Matmul,
     Qmatmul,
     Softmax,
+    SoftmaxLastDim,
 }
 
 #[derive(Parser, Debug)]
@@ -160,6 +169,7 @@ fn main() -> Result<()> {
         Task::Conv2d => run::<Conv2d>(args.iters)?,
         Task::Matmul => run::<Matmul>(args.iters)?,
         Task::Softmax => run::<Softmax>(args.iters)?,
+        Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
     }
     Ok(())
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c3b6ffa2..55da46f8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
         }
     }
 }
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+}
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 4ba8cfcc..5ca01b37 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
             [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
         ]
     );
+    let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            // (3, 1, 4) / 8, (1, 5, 9) / 15
+            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+            // (2, 1, 7) / 10, (8, 2, 8) / 18
+            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+        ]
+    );
     Ok(())
 }
 
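
Usage sketch (not part of the patch): a minimal example of how downstream code might call the new `softmax_last_dim` op and sanity-check it against the generic `softmax`. It assumes a binary crate that depends on `candle-core` (imported here as `candle`, matching the rename in the Cargo.toml above) and `candle-nn` 0.2.1; the tensor shape and the 1e-5 tolerance are arbitrary illustration choices.

use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1., (2, 4, 8), &Device::Cpu)?;
    // Generic softmax, usable along any dimension.
    let reference = candle_nn::ops::softmax(&xs, D::Minus1)?;
    // Custom op added by this patch: fused CPU kernel over the last dimension;
    // on non-CPU devices it currently falls back to the generic implementation.
    let fused = candle_nn::ops::softmax_last_dim(&xs)?;
    // The two results should agree up to floating-point rounding.
    let diff = (reference - fused)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-5);
    Ok(())
}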