Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
@@ -35,6 +35,12 @@ impl Layout {
         self.shape.dims()
     }
 
+    /// The dimension size for a specified dimension index.
+    pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
+        let dim = dim.to_index(&self.shape, "dim")?;
+        Ok(self.dims()[dim])
+    }
+
     pub fn shape(&self) -> &Shape {
         &self.shape
     }
@@ -142,6 +142,12 @@ impl Shape {
         &self.0
     }
 
+    /// The dimension size for a specified dimension index.
+    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
+        let dim = dim.to_index(self, "dim")?;
+        Ok(self.dims()[dim])
+    }
+
     /// The total number of elements, this is the product of all dimension sizes.
     pub fn elem_count(&self) -> usize {
         self.0.iter().product()
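The new `dim` accessors mirror `Tensor::dim` and are what the SDPA op below uses to validate q/k/v layouts. A minimal usage sketch, assuming candle's `D` index enum (`D::Minus1` resolves to the last axis) and the `Shape` type re-exported at the crate root:

use candle::{Result, Shape, D};

// Hedged sketch: check that two shapes agree on their trailing (head) dim.
// `dim` resolves relative indices and errors on out-of-range axes.
fn check_head_dim(q: &Shape, k: &Shape) -> Result<()> {
    if q.dim(D::Minus1)? != k.dim(D::Minus1)? {
        candle::bail!("q and k must share a head dim")
    }
    Ok(())
}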
@@ -8,7 +8,7 @@ use std::sync::RwLock;
 
 pub mod utils;
 pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split, EncoderProvider};
+use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
 
 const AFFINE: &str = include_str!("affine.metal");
 const BINARY: &str = include_str!("binary.metal");
@@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
 const SORT: &str = include_str!("sort.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const UNARY: &str = include_str!("unary.metal");
+const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
@@ -42,6 +43,7 @@ pub enum Source {
     Sort,
     Ternary,
     Unary,
+    Sdpa,
 }
 
 pub mod copy2d {
@@ -159,6 +161,17 @@ pub enum MetalKernelError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("Sdpa {variation} head size was {got}, expected {expected:?}")]
+    SdpaHeadSizeMismatch {
+        variation: &'static str,
+        got: usize,
+        expected: Vec<usize>,
+    },
+    #[error("Sdpa {variation} got dtype {got:?}")]
+    SdpaHeadDTypeMismatch {
+        variation: &'static str,
+        got: SdpaDType,
+    },
 }
 
 impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -207,6 +220,7 @@ impl Kernels {
             Source::Sort => SORT,
             Source::Ternary => TERNARY,
             Source::Unary => UNARY,
+            Source::Sdpa => SDPA,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -1627,6 +1641,313 @@ pub fn call_gemm(
     Ok(())
 }
 
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum SdpaDType {
+    BF16,
+    F16,
+    F32,
+}
+
+/// SDPA full is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q heads == kv heads
+/// - final type != bf16 (TODO maybe just template this kernel too?)
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_full(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    q_offset: usize,
+    q_shape: &[usize],
+    q_buffer: &Buffer,
+    k_offset: usize,
+    k_buffer: &Buffer,
+    v_offset: usize,
+    v_buffer: &Buffer,
+    output: &Buffer,
+    alpha: f32,
+    softcapping: f32,
+    itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+    #[derive(Debug)]
+    #[repr(C)]
+    struct MLXFastAttentionParams {
+        m: i32,
+        n: i32,
+        k: i32,
+
+        ldq: i32, // ldq == ldo
+        ldk: i32,
+        ldv: i32,
+        lds: i32,
+        ldo: i32,
+
+        tiles_n: i32,
+        tiles_m: i32,
+
+        batch_stride_q: i32,
+        batch_stride_k: i32,
+        batch_stride_v: i32,
+        batch_stride_o: i32,
+
+        swizzle_log: i32,
+        gemm_n_iterations_aligned: i32,
+        gemm_k_iterations_aligned: i32,
+        gemm_sv_m_block_iterations: i32,
+
+        batch_ndim: i32,
+        alpha: f32,
+        softcapping: f32,
+    }
+
+    let bk = q_shape.last().unwrap();
+
+    const BN: usize = 16;
+    const BM: usize = 16;
+    const WM: usize = 2;
+    const WN: usize = 2;
+
+    let name = match (bk, itype) {
+        (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
+        (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
+        (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
+        (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
+        (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
+        (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
+        (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
+        (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
+        (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
+        (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
+        (other, SdpaDType::F16 | SdpaDType::F32) => {
+            return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                variation: "full",
+                got: *other,
+                expected: vec![32, 64, 96, 128, 256],
+            })
+        }
+        (_, SdpaDType::BF16) => {
+            return Err(MetalKernelError::SdpaHeadDTypeMismatch {
+                variation: "full",
+                got: SdpaDType::BF16,
+            })
+        }
+    };
+
+    let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+    let encoder = ep.encoder();
+    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    // q = (bs, qhead, seq, hidden)
+    // k/v = (bs, kv_head, seq, hidden)
+
+    let qseq = q_shape[q_shape.len() - 2];
+
+    let m = q_shape[q_shape.len() - 2];
+    let n = m;
+    let k = q_shape[q_shape.len() - 1];
+    let bs_out = q_shape[0] * q_shape[1];
+
+    let batch_shape = [q_shape[0] * q_shape[1]];
+    let dk = q_shape[q_shape.len() - 1];
+    let ldq = dk;
+    let ldk = dk;
+    let ldv = dk;
+    let lds = BN;
+    let ldo = dk;
+
+    let tn = 1;
+    let tm = (m + BM - 1) / BM;
+
+    let b_stride_q = dk * qseq;
+    let b_stride_k = dk * qseq;
+    let b_stride_v = dk * qseq;
+    let b_stride_o = dk * qseq;
+    let swizzle_log = 0;
+    let gemm_n_iterations_aligned = (n + BN - 1) / BN;
+    let gemm_k_iterations_aligned = (k + bk - 1) / bk;
+    let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
+    let batch_ndim = batch_shape.len();
+
+    let alpha = if softcapping != 1. {
+        alpha / softcapping
+    } else {
+        alpha
+    };
+
+    let params = MLXFastAttentionParams {
+        m: m as i32,
+        n: n as i32,
+        k: k as i32,
+        ldq: ldq as i32,
+        ldk: ldk as i32,
+        ldv: ldv as i32,
+        lds: lds as i32,
+        ldo: ldo as i32,
+        tiles_n: tn,
+        tiles_m: tm as i32,
+        batch_stride_q: b_stride_q as i32,
+        batch_stride_k: b_stride_k as i32,
+        batch_stride_v: b_stride_v as i32,
+        batch_stride_o: b_stride_o as i32,
+        swizzle_log,
+        gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
+        gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
+        gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
+        batch_ndim: batch_ndim as i32,
+        alpha,
+        softcapping,
+    };
+    let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
+
+    impl EncoderParam for MLXFastAttentionParams {
+        fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+            encoder.set_bytes(
+                position,
+                core::mem::size_of::<MLXFastAttentionParams>() as u64,
+                &data as *const MLXFastAttentionParams as *const c_void,
+            );
+        }
+    }
+
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            params,
+            &batch_shape[..],
+            &batch_strides[..]
+        )
+    );
+
+    let grid_dims = MTLSize {
+        width: 1,
+        height: tm as u64,
+        depth: bs_out as u64,
+    };
+    let group_dims = MTLSize {
+        width: 32,
+        height: WM as u64,
+        depth: WN as u64,
+    };
+    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(grid_dims, group_dims);
+    Ok(())
+}
+
+/// SDPA vector is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    q_offset: usize,
+    q_shape: &[usize],
+    q_buffer: &Buffer,
+    k_offset: usize,
+    k_shape: &[usize],
+    k_stride: &[usize],
+    k_buffer: &Buffer,
+    v_offset: usize,
+    v_stride: &[usize],
+    v_buffer: &Buffer,
+    output: &Buffer,
+    alpha: f32,
+    softcapping: f32,
+    itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+    let bk = q_shape.last().unwrap();
+
+    let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+    let n = k_shape[2] as i32;
+    let b = (q_shape[0] * q_shape[1]) as i32;
+    let kstride = k_stride[1];
+    let vstride = v_stride[1];
+
+    let name = match (bk, itype) {
+        (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
+        (64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
+        (96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
+        (128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
+        (256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
+        (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
+        (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
+        (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
+        (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
+        (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
+        (32, SdpaDType::F32) => "sdpa_vector_float_32",
+        (64, SdpaDType::F32) => "sdpa_vector_float_64",
+        (96, SdpaDType::F32) => "sdpa_vector_float_96",
+        (128, SdpaDType::F32) => "sdpa_vector_float_128",
+        (256, SdpaDType::F32) => "sdpa_vector_float_256",
+        (other, _) => {
+            return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                variation: "vector",
+                got: *other,
+                expected: vec![32, 64, 96, 128, 256],
+            })
+        }
+    };
+
+    let alpha = if softcapping != 1. {
+        alpha / softcapping
+    } else {
+        alpha
+    };
+
+    let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+    let encoder = ep.encoder();
+    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    // q = (bs, qhead, seq, hidden)
+    // k/v = (bs, kv_head, kv_seq, hidden)
+
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            gqa_factor,
+            n,
+            kstride,
+            vstride,
+            alpha,
+            softcapping
+        )
+    );
+
+    let grid_dims = MTLSize {
+        width: 1,
+        height: b as u64,
+        depth: 1 as u64,
+    };
+    let group_dims = MTLSize {
+        width: 1024,
+        height: 1,
+        depth: 1,
+    };
+    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(grid_dims, group_dims);
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_im2col1d_strided(
     device: &Device,
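Two bits of arithmetic in the host code above deserve a gloss: the full kernel's iteration counts are plain ceil-divisions of the problem size by the 16x16 tile, and the vector kernel resolves GQA through `gqa_factor = q_heads / kv_heads`. A standalone, hedged sketch of both (function names are illustrative, not part of the kernel API):

// Ceil-division tiling, as used for tiles_m and the gemm_*_iterations_aligned fields.
fn ceil_div(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

// Under GQA, consecutive query heads share one kv head; this is the mapping
// implied by passing gqa_factor to the sdpa_vector kernels.
fn kv_head_for(query_head: usize, gqa_factor: usize) -> usize {
    query_head / gqa_factor
}

fn main() {
    assert_eq!(ceil_div(100, 16), 7); // a 100-row q needs 7 tiles of BM = 16
    assert_eq!(kv_head_for(5, 4), 1); // 8 q heads over 2 kv heads: head 5 -> kv head 1
}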
candle-metal-kernels/src/scaled_dot_product_attention.metal | 1257 ++++ (new file; diff suppressed because it is too large)
@@ -964,3 +964,193 @@ impl Module for Identity {
         Ok(xs.clone())
     }
 }
+
+#[allow(dead_code)]
+struct Sdpa {
+    scale: f32,
+    softcapping: f32,
+}
+
+impl candle::CustomOp3 for Sdpa {
+    fn name(&self) -> &'static str {
+        "metal-sdpa"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _s1: &CpuStorage,
+        _l1: &Layout,
+        _s2: &CpuStorage,
+        _l2: &Layout,
+        _s3: &CpuStorage,
+        _l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("SDPA has no cpu impl")
+    }
+
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
q: &candle::MetalStorage,
|
||||
q_l: &Layout,
|
||||
k: &candle::MetalStorage,
|
||||
k_l: &Layout,
|
||||
v: &candle::MetalStorage,
|
||||
v_l: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle_metal_kernels::SdpaDType;
|
||||
|
||||
let device = q.device();
|
||||
|
||||
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
|
||||
let elem_count: usize = out_dims.iter().product();
|
||||
|
||||
let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
|
||||
|
||||
+        // q,k must have matching emb dim
+        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
+            candle::bail!("`q` and `k` last dims must match");
+        }
+
+        // k,v must have matching n kv heads
+        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
+            candle::bail!("`k` and `v` head dims must match");
+        }
+
+        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
+        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
+            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
+        }
+
+        let k_head = k_l.dim(D::Minus1)?;
+        let q_head = q_l.dim(D::Minus1)?;
+        let q_seq = q_l.dim(2)?;
+
+        let mut implementation_supports_use_case = q_head == k_head;
+        let supported_head_dim =
+            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
+
+        const SDPA_FULL_THRESHOLD: usize = 2;
+
+        let supports_sdpa_full =
+            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
+        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
+
+        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
+
+        if !supported_head_dim {
+            candle::bail!(
+                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+        if !implementation_supports_use_case {
+            candle::bail!(
+                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+
+        for t in [k.dtype(), v.dtype()] {
+            if q.dtype() != t {
+                candle::bail!("all q, k, v dtypes must match.");
+            }
+        }
+
+        let itype = match q.dtype() {
+            DType::BF16 => SdpaDType::BF16,
+            DType::F16 => SdpaDType::F16,
+            DType::F32 => SdpaDType::F32,
+            other => candle::bail!("unsupported sdpa type {other:?}"),
+        };
+
+        let command_buffer = q.device().command_buffer()?;
+        if supports_sdpa_vector {
+            command_buffer.set_label("vector_attention");
+            candle_metal_kernels::call_sdpa_vector(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k_l.dims(),
+                k_l.stride(),
+                k.buffer(),
+                v_l.start_offset(),
+                v_l.stride(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else if supports_sdpa_full {
+            if q_l.dim(2)? != k_l.dim(2)? {
+                candle::bail!(
+                    "query and key sequence length must be equal if using full metal sdpa"
+                )
+            }
+
+            command_buffer.set_label("full_attention");
+            candle_metal_kernels::call_sdpa_full(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k.buffer(),
+                v_l.start_offset(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else {
+            candle::bail!("must be vector or full sdpa kernel");
+        }
+
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
+        Ok((newstorage, Shape::from_dims(&out_dims)))
+    }
+}
+
+/// Scaled dot product attention with a fused kernel.
+///
+/// Computes softmax(qk^T*scale)v.
+///
+/// **Input shapes:**
+/// - `q`: (bs, qhead, seq, hidden)
+/// - `k`: (bs, kv_head, kv_seq, hidden)
+/// - `v`: (bs, kv_head, kv_seq, v_hidden)
+/// - `scale` is applied before softmax.
+/// - If `softcapping` != 1.0:
+///   - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
+///
+/// **Output shape:** (bs, qhead, seq, v_hidden)
+///
+/// **Supported head dims:** 32, 64, 96, 128, 256.
+///
+/// ## On Metal:
+/// - If `seq` == 1:
+///   - Use a vectorized kernel
+///   - Supports `seq` != `kv_seq` (cross attn. support)
+///   - Supports GQA when `qhead` is a multiple of `kv_head`
+/// - Otherwise:
+///   - Use an alternate kernel
+///   - Requires `seq` == `kv_seq`
+///   - GQA is not supported (requires `qhead` == `kv_head`)
+pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
+    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
+}
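A minimal decode-time usage sketch of the new op (it mirrors the tests in candle-nn/tests/sdpa.rs and assumes a Metal device is available; the shapes are illustrative):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_metal(0)?;
    // One query token against a 24-token kv cache: seq == 1 selects the vector kernel.
    let q = Tensor::randn(0f32, 1f32, (1, 8, 1, 64), &device)?;
    let k = Tensor::randn(0f32, 1f32, (1, 8, 24, 64), &device)?;
    let v = Tensor::randn(0f32, 1f32, (1, 8, 24, 64), &device)?;
    let scale = 1f32 / (64f32).sqrt();
    // softcapping == 1.0 disables the tanh cap.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    assert_eq!(out.dims(), &[1, 8, 1, 64]);
    Ok(())
}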
candle-nn/tests/sdpa.rs | 206 ++++ (new file)
@@ -0,0 +1,206 @@
#[cfg(feature = "metal")]
|
||||
mod metal_sdpa_tests {
|
||||
#[test]
|
||||
fn sdpa_full() -> candle::Result<()> {
|
||||
use candle::{DType, Device, Tensor};
|
||||
|
||||
// Force seqlen = 100
|
||||
const BS: usize = 4;
|
||||
const R: usize = 4;
|
||||
const L: usize = 4;
|
||||
const DK: usize = 64;
|
||||
const H: usize = 3;
|
||||
let scale: f64 = f64::from(DK as u32).sqrt().recip();
|
||||
|
||||
let device = Device::new_metal(0)?;
|
||||
|
||||
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
|
||||
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
||||
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
|
||||
|
||||
let ground_truth = {
|
||||
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
|
||||
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
|
||||
.to_dtype(q.dtype())?;
|
||||
att.matmul(&v.clone())?
|
||||
};
|
||||
|
||||
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
|
||||
|
||||
assert_eq!(ground_truth.shape(), sdpa_output.shape());
|
||||
|
||||
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
|
||||
assert!(error <= 0.0005, "{}", error);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
+    #[test]
+    fn sdpa_vector() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+
+        // Allow vectorized, seqlen = 1
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 1;
+        const DK: usize = 64;
+        const H: usize = 3;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+                .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0001, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_full_softcapping() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+        use std::ops::{Div, Mul};
+
+        // Non-unit seqlen so the full (softcapped) kernel is used.
+        const BS: usize = 4;
+        const R: usize = 4;
+        const L: usize = 4;
+        const DK: usize = 64;
+        const H: usize = 3;
+        const SOFTCAP: f64 = 50.;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(
+                &att.to_dtype(DType::F32)?
+                    .div(SOFTCAP)?
+                    .tanh()?
+                    .mul(SOFTCAP)?,
+            )?
+            .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0004, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_vector_softcapping() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+        use std::ops::{Div, Mul};
+
+        // Allow vectorized, seqlen = 1
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 1;
+        const DK: usize = 64;
+        const H: usize = 3;
+        const SOFTCAP: f64 = 50.;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(
+                &att.to_dtype(DType::F32)?
+                    .div(SOFTCAP)?
+                    .tanh()?
+                    .mul(SOFTCAP)?,
+            )?
+            .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0001, "{}", error);
+
+        Ok(())
+    }
+
+    #[test]
+    fn sdpa_vector_cross() -> candle::Result<()> {
+        use candle::{DType, Device, Tensor};
+
+        // Allow vectorized, seqlen = 1. Simulate the cross-attention case where R != L, R = 1.
+        const BS: usize = 4;
+        const R: usize = 1;
+        const L: usize = 24;
+        const DK: usize = 64;
+        const H: usize = 3;
+        let scale: f64 = f64::from(DK as u32).sqrt().recip();
+
+        let device = Device::new_metal(0)?;
+
+        let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
+        let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+        let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
+
+        let ground_truth = {
+            let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
+            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
+                .to_dtype(q.dtype())?;
+            att.matmul(&v.clone())?
+        };
+
+        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
+
+        assert_eq!(ground_truth.shape(), sdpa_output.shape());
+
+        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
+            .sum_all()?
+            .to_scalar()?;
+
+        assert!(error <= 0.0013, "{}", error);
+
+        Ok(())
+    }
+}
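All five tests share the same summed relative-error check; factored out, it would look roughly like this (a hedged sketch, not part of the test file):

use candle::{Result, Tensor};

// Sum over elements of |a - b| / |b|, matching the assertions above.
fn summed_relative_error(a: &Tensor, b: &Tensor) -> Result<f32> {
    ((a - b)?.abs()? / b.abs()?)?.sum_all()?.to_scalar::<f32>()
}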
@@ -205,21 +205,27 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // Support for MQA, useful for 70B models and mistral.
-        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
-        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
+        let y = if q.device().is_metal() && seq_len == 1 {
+            // SDPA will do MQA for us
+            candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)?
+        } else {
+            // Support for MQA, useful for 70B models and mistral.
+            let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
+            let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
 
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let att = match mask {
-            None => att,
-            Some(mask) => {
-                let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
-            }
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let att = match mask {
+                None => att,
+                Some(mask) => {
+                    let mask = mask.broadcast_as(att.shape())?;
+                    masked_fill(&att, &mask, &self.neg_inf)?
+                }
+            };
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?
         };
-        let att = candle_nn::ops::softmax_last_dim(&att)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
 
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
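The integration keeps the masked, strided path for prefill and routes only single-token decode steps to the fused op; a hedged summary of that routing (illustrative names, not from the source):

// Illustrative: when quantized llama can take the fused Metal path.
fn use_fused_sdpa(is_metal: bool, seq_len: usize) -> bool {
    // Decoding only: the vector kernel handles GQA itself (no repeat_kv needed)
    // and ignores the causal mask, which is trivial at seq_len == 1.
    is_metal && seq_len == 1
}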