Add some fast Metal MLX SDPA kernels (#2584)

* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Eric Buehler
2024-11-05 03:28:00 -05:00
committed by GitHub
parent 6454597943
commit e2b6b367fa
7 changed files with 2006 additions and 14 deletions

View File

@ -964,3 +964,193 @@ impl Module for Identity {
Ok(xs.clone())
}
}
#[allow(dead_code)]
struct Sdpa {
scale: f32,
softcapping: f32,
}
impl candle::CustomOp3 for Sdpa {
fn name(&self) -> &'static str {
"metal-sdpa"
}
fn cpu_fwd(
&self,
_s1: &CpuStorage,
_l1: &Layout,
_s2: &CpuStorage,
_l2: &Layout,
_s3: &CpuStorage,
_l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("SDPA has no cpu impl")
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
q: &candle::MetalStorage,
q_l: &Layout,
k: &candle::MetalStorage,
k_l: &Layout,
v: &candle::MetalStorage,
v_l: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle_metal_kernels::SdpaDType;
let device = q.device();
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
let elem_count: usize = out_dims.iter().product();
let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
// q,k must have matching emb dim
if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
candle::bail!("`q` and `k` last dims must match");
}
// k,v must have matching n kv heads
if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
candle::bail!("`k` and `v` head dims must match");
}
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}
let k_head = k_l.dim(D::Minus1)?;
let q_head = q_l.dim(D::Minus1)?;
let q_seq = q_l.dim(2)?;
let mut implementation_supports_use_case = q_head == k_head;
let supported_head_dim =
q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
const SDPA_FULL_THRESHOLD: usize = 2;
let supports_sdpa_full =
q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
if !supported_head_dim {
candle::bail!(
"Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
if !implementation_supports_use_case {
candle::bail!(
"Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
for t in [k.dtype(), v.dtype()] {
if q.dtype() != t {
candle::bail!("all q, k, v dtypes must match.");
}
}
let itype = match q.dtype() {
DType::BF16 => SdpaDType::BF16,
DType::F16 => SdpaDType::F16,
DType::F32 => SdpaDType::F32,
other => candle::bail!("unsupported sdpa type {other:?}"),
};
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
command_buffer.set_label("vector_attention");
candle_metal_kernels::call_sdpa_vector(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k_l.dims(),
k_l.stride(),
k.buffer(),
v_l.start_offset(),
v_l.stride(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
"query and key sequence length must be equal if using full metal sdpa"
)
}
command_buffer.set_label("full_attention");
candle_metal_kernels::call_sdpa_full(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k.buffer(),
v_l.start_offset(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else {
candle::bail!("must be vector or full sdpa kernel");
}
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
Ok((newstorage, Shape::from_dims(&out_dims)))
}
}
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Inputs shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
/// - Use a vectorized kernel
/// - Supports `seq` != `kv_seq` (cross attn. support)
/// - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
/// - Use an alternate kernel
/// - Requires `seq` == `kv_seq`
/// - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}

206
candle-nn/tests/sdpa.rs Normal file
View File

@ -0,0 +1,206 @@
#[cfg(feature = "metal")]
mod metal_sdpa_tests {
#[test]
fn sdpa_full() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Force seqlen = 100
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0005, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_full_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0004, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_cross() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 24;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0013, "{}", error);
Ok(())
}
}