Mirror of https://github.com/huggingface/candle.git
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
@@ -964,3 +964,193 @@ impl Module for Identity {
        Ok(xs.clone())
    }
}

#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        // q,k must have matching emb dim
        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        // k,v must have matching n kv heads
        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim =
            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;

        const SDPA_FULL_THRESHOLD: usize = 2;

        let supports_sdpa_full =
            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let command_buffer = q.device().command_buffer()?;
        if supports_sdpa_vector {
            command_buffer.set_label("vector_attention");
            candle_metal_kernels::call_sdpa_vector(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k_l.dims(),
                k_l.stride(),
                k.buffer(),
                v_l.start_offset(),
                v_l.stride(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else if supports_sdpa_full {
            if q_l.dim(2)? != k_l.dim(2)? {
                candle::bail!(
                    "query and key sequence length must be equal if using full metal sdpa"
                )
            }

            command_buffer.set_label("full_attention");
            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, Shape::from_dims(&out_dims)))
    }
}

/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Input shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
///   - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
///   - Use a vectorized kernel
///   - Supports `seq` != `kv_seq` (cross attn. support)
///   - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
///   - Use an alternate kernel
///   - Requires `seq` == `kv_seq`
///   - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}
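
For context, a minimal usage sketch of the new op from the caller's side. It assumes a Metal-enabled build of candle (the op has no CPU forward) and that the function is exported as `candle_nn::ops::sdpa`; the shapes, the 1/sqrt(head_dim) scale, and the `sdpa_decode_example` wrapper are illustrative, not part of this commit.

use candle::{Device, Result, Tensor};

fn sdpa_decode_example() -> Result<Tensor> {
    // The op only has a Metal forward pass, so this needs a Metal device
    // and the `metal` feature enabled.
    let device = Device::new_metal(0)?;

    // Decode-style shapes: q has seq == 1, which selects the vectorized
    // kernel. GQA: 32 query heads share 8 kv heads; kv_seq (64) != seq (1).
    let q = Tensor::randn(0f32, 1f32, (1, 32, 1, 128), &device)?;
    let k = Tensor::randn(0f32, 1f32, (1, 8, 64, 128), &device)?;
    let v = Tensor::randn(0f32, 1f32, (1, 8, 64, 128), &device)?;

    // Standard attention scaling; softcapping == 1.0 leaves the logits uncapped.
    let scale = 1f32 / (128f32).sqrt();
    candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)
}

With `seq == 1` the dispatch above takes the vectorized decode path; a prompt with `seq > 1` would instead need `seq == kv_seq` and matching head counts to hit the full kernel.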