Add some fast Metal MLX SDPA kernels (#2584)

* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Eric Buehler
2024-11-05 03:28:00 -05:00
committed by GitHub
parent 6454597943
commit e2b6b367fa
7 changed files with 2006 additions and 14 deletions

View File

@ -35,6 +35,12 @@ impl Layout {
self.shape.dims()
}
/// The dimension size for a specified dimension index.
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(&self.shape, "dim")?;
Ok(self.dims()[dim])
}
pub fn shape(&self) -> &Shape {
&self.shape
}

View File

@ -142,6 +142,12 @@ impl Shape {
&self.0
}
/// The dimension size for a specified dimension index.
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(self, "dim")?;
Ok(self.dims()[dim])
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()

View File

@ -8,7 +8,7 @@ use std::sync::RwLock;
pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@ -42,6 +43,7 @@ pub enum Source {
Sort,
Ternary,
Unary,
Sdpa,
}
pub mod copy2d {
@ -159,6 +161,17 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
SdpaHeadSizeMismatch {
variation: &'static str,
got: usize,
expected: Vec<usize>,
},
#[error("Sdpa {variation} got dtype {got:?}")]
SdpaHeadDTypeMismatch {
variation: &'static str,
got: SdpaDType,
},
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@ -207,6 +220,7 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -1627,6 +1641,313 @@ pub fn call_gemm(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,
F16,
F32,
}
/// SDPA full is supported when:
/// - q head dim == 64, 128
/// - no mask
/// - q heads == kv heads
/// - final type != bf16 (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct MLXFastAttentionParams {
m: i32,
n: i32,
k: i32,
ldq: i32, // ldq == ldo
ldk: i32,
ldv: i32,
lds: i32,
ldo: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_q: i32,
batch_stride_k: i32,
batch_stride_v: i32,
batch_stride_o: i32,
swizzle_log: i32,
gemm_n_iterations_aligned: i32,
gemm_k_iterations_aligned: i32,
gemm_sv_m_block_iterations: i32,
batch_ndim: i32,
alpha: f32,
softcapping: f32,
}
let bk = q_shape.last().unwrap();
const BN: usize = 16;
const BM: usize = 16;
const WM: usize = 2;
const WN: usize = 2;
let name = match (bk, itype) {
(32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
(64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
(96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
(128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
(256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
(32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
(64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
(96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
(128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
(256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
(other, SdpaDType::F16 | SdpaDType::F32) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
(_, SdpaDType::BF16) => {
return Err(MetalKernelError::SdpaHeadDTypeMismatch {
variation: "full",
got: SdpaDType::BF16,
})
}
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)
let qseq = q_shape[q_shape.len() - 2];
let m = q_shape[q_shape.len() - 2];
let n = m;
let k = q_shape[q_shape.len() - 1];
let bs_out = q_shape[0] * q_shape[1];
let batch_shape = [q_shape[0] * q_shape[1]];
let dk = q_shape[q_shape.len() - 1];
let ldq = dk;
let ldk = dk;
let ldv = dk;
let lds = BN;
let ldo = dk;
let tn = 1;
let tm = (m + BM - 1) / BM;
let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
let gemm_n_iterations_aligned = (n + BN - 1) / BN;
let gemm_k_iterations_aligned = (k + bk - 1) / bk;
let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
let batch_ndim = batch_shape.len();
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let params = MLXFastAttentionParams {
m: m as i32,
n: n as i32,
k: k as i32,
ldq: ldq as i32,
ldk: ldk as i32,
ldv: ldv as i32,
lds: lds as i32,
ldo: ldo as i32,
tiles_n: tn,
tiles_m: tm as i32,
batch_stride_q: b_stride_q as i32,
batch_stride_k: b_stride_k as i32,
batch_stride_v: b_stride_v as i32,
batch_stride_o: b_stride_o as i32,
swizzle_log,
gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
batch_ndim: batch_ndim as i32,
alpha,
softcapping,
};
let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
impl EncoderParam for MLXFastAttentionParams {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<MLXFastAttentionParams>() as u64,
&data as *const MLXFastAttentionParams as *const c_void,
);
}
}
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
&batch_shape[..],
&batch_strides[..]
)
);
let grid_dims = MTLSize {
width: 1,
height: tm as u64,
depth: bs_out as u64,
};
let group_dims = MTLSize {
width: 32,
height: WM as u64,
depth: WN as u64,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
/// SDPA full is supported when:
/// - q head dim == 64, 96, 128
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_stride: &[usize],
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
let bk = q_shape.last().unwrap();
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let kstride = k_stride[1];
let vstride = v_stride[1];
let name = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
(32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
(32, SdpaDType::F32) => "sdpa_vector_float_32",
(64, SdpaDType::F32) => "sdpa_vector_float_64",
(96, SdpaDType::F32) => "sdpa_vector_float_96",
(128, SdpaDType::F32) => "sdpa_vector_float_128",
(256, SdpaDType::F32) => "sdpa_vector_float_256",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
};
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, kv_seq, hidden)
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
gqa_factor,
n,
kstride,
vstride,
alpha,
softcapping
)
);
let grid_dims = MTLSize {
width: 1,
height: b as u64,
depth: 1 as u64,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,

File diff suppressed because it is too large Load Diff

View File

@ -964,3 +964,193 @@ impl Module for Identity {
Ok(xs.clone())
}
}
#[allow(dead_code)]
struct Sdpa {
scale: f32,
softcapping: f32,
}
impl candle::CustomOp3 for Sdpa {
fn name(&self) -> &'static str {
"metal-sdpa"
}
fn cpu_fwd(
&self,
_s1: &CpuStorage,
_l1: &Layout,
_s2: &CpuStorage,
_l2: &Layout,
_s3: &CpuStorage,
_l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("SDPA has no cpu impl")
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
q: &candle::MetalStorage,
q_l: &Layout,
k: &candle::MetalStorage,
k_l: &Layout,
v: &candle::MetalStorage,
v_l: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle_metal_kernels::SdpaDType;
let device = q.device();
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
let elem_count: usize = out_dims.iter().product();
let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
// q,k must have matching emb dim
if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
candle::bail!("`q` and `k` last dims must match");
}
// k,v must have matching n kv heads
if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
candle::bail!("`k` and `v` head dims must match");
}
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}
let k_head = k_l.dim(D::Minus1)?;
let q_head = q_l.dim(D::Minus1)?;
let q_seq = q_l.dim(2)?;
let mut implementation_supports_use_case = q_head == k_head;
let supported_head_dim =
q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
const SDPA_FULL_THRESHOLD: usize = 2;
let supports_sdpa_full =
q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
if !supported_head_dim {
candle::bail!(
"Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
if !implementation_supports_use_case {
candle::bail!(
"Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
for t in [k.dtype(), v.dtype()] {
if q.dtype() != t {
candle::bail!("all q, k, v dtypes must match.");
}
}
let itype = match q.dtype() {
DType::BF16 => SdpaDType::BF16,
DType::F16 => SdpaDType::F16,
DType::F32 => SdpaDType::F32,
other => candle::bail!("unsupported sdpa type {other:?}"),
};
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
command_buffer.set_label("vector_attention");
candle_metal_kernels::call_sdpa_vector(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k_l.dims(),
k_l.stride(),
k.buffer(),
v_l.start_offset(),
v_l.stride(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
"query and key sequence length must be equal if using full metal sdpa"
)
}
command_buffer.set_label("full_attention");
candle_metal_kernels::call_sdpa_full(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k.buffer(),
v_l.start_offset(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else {
candle::bail!("must be vector or full sdpa kernel");
}
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
Ok((newstorage, Shape::from_dims(&out_dims)))
}
}
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Inputs shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
/// - Use a vectorized kernel
/// - Supports `seq` != `kv_seq` (cross attn. support)
/// - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
/// - Use an alternate kernel
/// - Requires `seq` == `kv_seq`
/// - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}

206
candle-nn/tests/sdpa.rs Normal file
View File

@ -0,0 +1,206 @@
#[cfg(feature = "metal")]
mod metal_sdpa_tests {
#[test]
fn sdpa_full() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Force seqlen = 100
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0005, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_full_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0004, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_cross() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 24;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0013, "{}", error);
Ok(())
}
}

View File

@ -205,21 +205,27 @@ impl LayerWeights {
};
self.kv_cache = Some((k.clone(), v.clone()));
// Support for MQA, useful for 70B models and mistral.
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let y = if q.device().is_metal() && seq_len == 1 {
// SDPA will do MQA for us
candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)?
} else {
// Support for MQA, useful for 70B models and mistral.
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attention_wo.forward(&y)?;
Ok(y)