Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
Add a custom softmax implementation. (#744)

* Add a custom softmax implementation.
* Add softmax_last_dim to the benchmarks.
* Add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on CUDA.
* Add a TODO for the CUDA kernel.
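For orientation, a minimal usage sketch of the new entry point (a sketch assuming the candle and candle-nn crates around version 0.2.1; the tensor shape is illustrative):

    use candle::{Device, Result, Tensor, D};

    fn main() -> Result<()> {
        // Rows of logits; softmax_last_dim normalizes along the last axis.
        let xs = Tensor::randn(0f32, 1., (2, 4), &Device::Cpu)?;
        let fused = candle_nn::ops::softmax_last_dim(&xs)?;
        // The generic softmax over an explicit dimension gives the same values.
        let generic = candle_nn::ops::softmax(&xs, D::Minus1)?;
        println!("{fused}\n{generic}");
        Ok(())
    }

On CPU, softmax_last_dim dispatches to the fused kernel added in this commit; on CUDA it currently falls back to the generic implementation (see the TODO in the kernel below).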
@@ -198,7 +198,7 @@ impl CrossAttention {
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
             let xs = {
                 let _enter = self.span_softmax.enter();
-                nn::ops::softmax(&xs, D::Minus1)?
+                nn::ops::softmax_last_dim(&xs)?
             };
             xs.matmul(&value)?.to_dtype(in_dtype)?
         };
@@ -12,12 +12,16 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+half = { workspace = true }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
+rayon = { workspace = true }
 safetensors = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
+clap = { workspace = true }
 
 [features]
 default = []
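The new dependencies line up with the kernel added further down: rayon provides the par_chunks parallel iteration, num-traits supplies the Float bound on the generic CPU routine, and half covers the f16/bf16 element types. clap moves into dev-dependencies for the benchmark example.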
@@ -5,19 +5,10 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle_core::quantized::GgmlType;
-use candle_core::{Device, Result, Tensor, D};
+use candle::quantized::GgmlType;
+use candle::{Device, Result, Tensor, D};
 use clap::{Parser, Subcommand};
 
-fn softmax<D: candle_core::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
-    let dim = dim.to_index(xs.shape(), "softmax")?;
-    let max = xs.max_keepdim(dim)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(dim)?;
-    num.broadcast_div(&den)
-}
-
 trait Benchmark {
     type PreProcessData;
     type RunResult;
@@ -86,12 +77,12 @@ impl Benchmark for Matmul {
 // https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp
 struct QMatMul;
 impl Benchmark for QMatMul {
-    type PreProcessData = (candle_core::quantized::QMatMul, Tensor);
+    type PreProcessData = (candle::quantized::QMatMul, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
-        let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
-        let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
-        let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
+        let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
+        let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+        let mm = candle::quantized::QMatMul::from_qtensor(mm);
         let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
         Ok((mm, arg))
     }
@@ -114,7 +105,24 @@ impl Benchmark for Softmax {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        softmax(d, D::Minus1)
+        candle_nn::ops::softmax(d, D::Minus1)
+    }
+
+    const ITERS: usize = 100;
+}
+
+struct SoftmaxLastDim;
+impl Benchmark for SoftmaxLastDim {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+    fn preprocess() -> Result<Self::PreProcessData> {
+        // Typical whisper tiny size.
+        let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
+        Ok(x)
+    }
+
+    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+        candle_nn::ops::softmax_last_dim(d)
     }
 
     const ITERS: usize = 100;
@@ -140,6 +148,7 @@ enum Task {
     Matmul,
     Qmatmul,
     Softmax,
+    SoftmaxLastDim,
 }
 
 #[derive(Parser, Debug)]
@@ -160,6 +169,7 @@ fn main() -> Result<()> {
         Task::Conv2d => run::<Conv2d>(args.iters)?,
         Task::Matmul => run::<Matmul>(args.iters)?,
         Task::Softmax => run::<Softmax>(args.iters)?,
+        Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
         Task::Qmatmul => run::<QMatMul>(args.iters)?,
     }
     Ok(())
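clap's derive maps enum variant names to kebab-case by default, so the new Task::SoftmaxLastDim should be reachable as softmax-last-dim. Assuming the benchmark example keeps its cpu_benchmarks file name (the file path is not visible in this view), the new case would be invoked roughly as:

    cargo run --release --example cpu_benchmarks -- softmax-last-dim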
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
         }
     }
 }
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+}
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
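The CPU kernel above is the standard numerically stable softmax: each row first has its maximum subtracted, then exponentials are taken and normalized. For a row x with m = max_j x_j:

    \operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}

Shifting by m multiplies numerator and denominator by the same factor e^{-m}, so the result is unchanged, but every exponent is at most 0, which avoids overflow in the low-precision f16/bf16 paths.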
@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
             [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
         ]
     );
+    let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            // (3, 1, 4) / 8, (1, 5, 9) / 15
+            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+            // (2, 1, 7) / 10, (8, 2, 8) / 18
+            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+        ]
+    );
     Ok(())
 }
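The new assertions rely on a convenient identity: applying softmax to log(x) simply renormalizes x, since exp(log x_i) = x_i, giving softmax(log x)_i = x_i / sum_j x_j. The inline comments spell this out, e.g. the row (3, 1, 4) maps to (3, 1, 4) / 8 = (0.375, 0.125, 0.5).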