Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask
* Dispatch to the 2pass kernel
* Format
@@ -1906,7 +1906,12 @@ pub fn call_sdpa_vector(
         alpha
     };
 
-    let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
+    let constants = Some(ConstantValues::new(vec![(
+        20,
+        Value::Bool(/* sdpa_vector_has_mask */ false),
+    )]));
+
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
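Note: `sdpa_vector_has_mask` is a Metal function constant (index 20), so the mask-related arguments are compiled out when the pipeline is specialized rather than checked at runtime; the single-pass wrapper above always pins the constant to `false`. As a purely hypothetical sketch (not part of this change), a masked dispatch would differ only in the constant value, using the same `ConstantValues`/`Value` API shown above:

// Hypothetical sketch only: specializing the same kernel with the mask enabled.
// This commit always passes `false` here.
let constants = Some(ConstantValues::new(vec![(
    20,
    Value::Bool(/* sdpa_vector_has_mask */ true),
)]));
// Each distinct set of constant values yields its own specialized pipeline
// for the same kernel name.
let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;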
@@ -1948,6 +1953,187 @@ pub fn call_sdpa_vector(
     Ok(())
 }
 
+pub const SDPA_2PASS_BLOCKS: usize = 32;
+
+/// SDPA vector 2pass is supported when:
+/// - q head dim == 64, 96, 128
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector_2pass(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    q_offset: usize,
+    q_shape: &[usize],
+    q_buffer: &Buffer,
+    k_offset: usize,
+    k_shape: &[usize],
+    k_stride: &[usize],
+    k_buffer: &Buffer,
+    v_offset: usize,
+    v_stride: &[usize],
+    v_buffer: &Buffer,
+    output: &Buffer,
+    intermediate: &Buffer,
+    sums: &Buffer,
+    maxs: &Buffer,
+    alpha: f32,
+    softcapping: f32,
+    itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+    let bk = q_shape.last().unwrap();
+
+    // First pass
+    {
+        let name_pass1 = match (bk, itype) {
+            (32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
+            (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
+            (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
+            (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
+            (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
+            (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
+            (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
+            (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
+            (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
+            (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
+            (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
+            (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
+            (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
+            (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
+            (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
+            (other, _) => {
+                return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                    variation: "vector_2pass_1",
+                    got: *other,
+                    expected: vec![32, 64, 96, 128, 256],
+                })
+            }
+        };
+
+        let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+        let n = k_shape[2] as i32;
+        let b = (q_shape[0] * q_shape[1]) as i32;
+        let kstride = k_stride[1];
+        let vstride = v_stride[1];
+
+        let alpha = if softcapping != 1. {
+            alpha / softcapping
+        } else {
+            alpha
+        };
+
+        let constants = Some(ConstantValues::new(vec![(
+            20,
+            Value::Bool(/* sdpa_vector_has_mask */ false),
+        )]));
+
+        let pipeline =
+            kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
+        let encoder = ep.encoder();
+        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+        encoder.set_compute_pipeline_state(&pipeline);
+
+        // q = (bs, qhead, seq, hidden)
+        // k/v = (bs, kv_head, kv_seq, hidden)
+
+        set_params!(
+            encoder,
+            (
+                (q_buffer, q_offset),
+                (k_buffer, k_offset),
+                (v_buffer, v_offset),
+                intermediate,
+                sums,
+                maxs,
+                gqa_factor,
+                n,
+                kstride,
+                vstride,
+                alpha,
+                softcapping
+            )
+        );
+
+        let grid_dims = MTLSize {
+            width: 1,
+            height: b as u64,
+            depth: SDPA_2PASS_BLOCKS as u64,
+        };
+        let group_dims = MTLSize {
+            width: 8 * 32,
+            height: 1,
+            depth: 1,
+        };
+        encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
+        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
+        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
+
+        encoder.dispatch_thread_groups(grid_dims, group_dims);
+    }
+
+    // Final pass
+    {
+        let name_pass2 = match (bk, itype) {
+            (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
+            (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
+            (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
+            (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
+            (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
+            (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
+            (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
+            (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
+            (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
+            (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
+            (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
+            (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
+            (96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
+            (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
+            (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
+            (other, _) => {
+                return Err(MetalKernelError::SdpaHeadSizeMismatch {
+                    variation: "vector_2pass_2",
+                    got: *other,
+                    expected: vec![32, 64, 96, 128, 256],
+                })
+            }
+        };
+
+        let b = (q_shape[0] * q_shape[1]) as i32;
+
+        let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
+        let encoder = ep.encoder();
+        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+        encoder.set_compute_pipeline_state(&pipeline);
+
+        // q = (bs, qhead, seq, hidden)
+        // k/v = (bs, kv_head, kv_seq, hidden)
+
+        set_params!(encoder, (intermediate, sums, maxs, output));
+
+        let grid_dims = MTLSize {
+            width: 1,
+            height: b as u64,
+            depth: 1,
+        };
+        let group_dims = MTLSize {
+            width: 1024,
+            height: 1,
+            depth: 1,
+        };
+        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
+        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
+        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
+        encoder.use_resource(output, metal::MTLResourceUsage::Write);
+
+        encoder.dispatch_thread_groups(grid_dims, group_dims);
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_im2col1d_strided(
     device: &Device,
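Note on scratch sizing: pass 1 writes its partial outputs per (batch, query-head) row and per block, and pass 2 reads them back the same way (see the kernel hunks below, where `out`, `sums` and `maxs` are offset by `head_idx * blocks` terms). For the single-query decode case these vector kernels target, a minimal sketch of the element counts the caller-provided buffers need, assuming `q_shape` is `(bs, q_heads, 1, head_dim)`:

// Illustrative arithmetic only, not code from this commit.
let rows = q_shape[0] * q_shape[1]; // one row per (batch, query head)
let head_dim = *q_shape.last().unwrap();
let intermediate_len = rows * SDPA_2PASS_BLOCKS * head_dim; // partial outputs, f32
let sums_len = rows * SDPA_2PASS_BLOCKS; // per-block softmax sums, f32
let maxs_len = rows * SDPA_2PASS_BLOCKS; // per-block running maxima, f32

The hunks that follow are in the Metal shader source, where the kernels themselves are defined.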
@@ -47,6 +47,8 @@ struct MLXScaledDotProductAttentionParams {
 
 // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
 
+constant bool sdpa_vector_has_mask [[function_constant(20)]];
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
@@ -59,14 +61,16 @@ template <typename T, int D>
     const constant size_t& v_stride,
     const constant float& scale,
     const constant float& softcapping,
+    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
   constexpr int BD = 32;
   constexpr int elem_per_thread = D / BD;
-  const int stride = BN * D;
+  constexpr int stride = BN * D;
 
   typedef float U;
 
@@ -84,6 +88,9 @@ template <typename T, int D>
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
   values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  if (sdpa_vector_has_mask) {
+    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+  }
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
@@ -99,6 +106,138 @@ template <typename T, int D>
 
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
+    if (!sdpa_vector_has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }
+
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
+      if (softcapping != 1.) {
+        score = precise::tanh(score);
+        score = score * softcapping;
+      }
+
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
+
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
+
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
+    }
+
+    // Move the pointers to the next kv
+    keys += stride;
+    values += stride;
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = max_scores[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BD + simd_gid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device float* out [[buffer(3)]],
+    device float* sums [[buffer(4)]],
+    device float* maxs [[buffer(5)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& v_stride,
+    const constant float& scale,
+    const constant float& softcapping,
+    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
+    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int stride = BN * D;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + simd_lid * elem_per_thread;
+  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  if (sdpa_vector_has_mask) {
+    mask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_seq_stride;
+  }
+  sums += head_idx * blocks + block_idx;
+  maxs += head_idx * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < elem_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -1e9;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+    if (!sdpa_vector_has_mask || mask[0]) {
       // Read the key
       for (int i = 0; i < elem_per_thread; i++) {
         k[i] = keys[i];
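The per-key loop above is a streaming (online) softmax: each simdgroup keeps a running maximum m, a running sum of exponentials \sigma, and an unnormalized output accumulator o, and rescales them whenever a larger score arrives. In the notation of the code (`new_max`, `factor`, `exp_score`), a key with score s and value row v updates the state as

m' = \max(m, s), \qquad \sigma' = \sigma \, e^{m - m'} + e^{s - m'}, \qquad o' = o \, e^{m - m'} + e^{s - m'} \, v .

When softcapping is not 1, the host divides the scale by the softcap value and the kernel applies s \leftarrow \mathrm{softcapping} \cdot \tanh(s) to the summed score before this update, which is the `precise::tanh` branch above.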
@@ -127,27 +266,54 @@ template <typename T, int D>
       for (int i = 0; i < elem_per_thread; i++) {
         o[i] = o[i] * factor + exp_score * values[i];
       }
+    }
 
     // Move the pointers to the next kv
-    keys += stride;
-    values += stride;
+    keys += blocks * stride;
+    values += blocks * stride;
+    if (sdpa_vector_has_mask) {
+      mask += BN * blocks * mask_seq_stride;
+    }
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Each thread has a partial part of the output so we need to combine them.
-
-  // First let's communicate the max and sum_exp
-  if (simd_lid == 0) {
-    max_scores[simd_gid] = max_score;
-    sum_exp_scores[simd_gid] = sum_exp_score;
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-  max_score = max_scores[simd_lid];
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+    const device float* partials [[buffer(0)]],
+    const device float* sums [[buffer(1)]],
+    const device float* maxs [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U o[elem_per_thread];
+  threadgroup U outputs[BN * BD];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+  sums += head_idx * blocks;
+  maxs += head_idx * blocks;
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // First everybody reads the max and sum_exp
+  U max_score = maxs[simd_lid];
   U new_max = simd_max(max_score);
   U factor = fast::exp(max_score - new_max);
-  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
+  U sum_exp_score = simd_sum(sums[simd_lid] * factor);
 
-  // Now we need to aggregate all the outputs
+  // Now read the block into registers and then use shared memory to transpose
+  // it
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = partials[i];
+  }
   for (int i = 0; i < elem_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
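Pass 2 merges the per-block partial results. Assuming pass 1 leaves, for every (batch, head) row and block b, an unnormalized accumulator \hat{o}_b in `partials`, its sum of exponentials s_b in `sums`, and its running maximum m_b in `maxs`, the reduction sketched above (global max via `simd_max`, rescale by `factor`, divide by the merged sum) amounts to

m = \max_b m_b, \qquad \mathrm{out} = \frac{\sum_b e^{m_b - m} \, \hat{o}_b}{\sum_b e^{m_b - m} \, s_b} ,

which reproduces the single-pass softmax result while letting the 32 blocks of pass 1 run independently over long KV sequences.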
@@ -1238,9 +1404,41 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
       const constant size_t& v_stride, \
       const constant float& scale, \
       const constant float& softcapping, \
+      const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
+      uint simd_lid [[thread_index_in_simdgroup]]); \
+  template [[host_name("sdpa_vector_2pass_1_" #type "_" #head_dim)]] \
+  [[kernel]] void sdpa_vector_2pass_1<type, head_dim>( \
+      const device type* queries [[buffer(0)]], \
+      const device type* keys [[buffer(1)]], \
+      const device type* values [[buffer(2)]], \
+      device float* out [[buffer(3)]], \
+      device float* sums [[buffer(4)]], \
+      device float* maxs [[buffer(5)]], \
+      const constant int& gqa_factor, \
+      const constant int& N, \
+      const constant size_t& k_stride, \
+      const constant size_t& v_stride, \
+      const constant float& scale, \
+      const constant float& softcapping, \
+      const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
+      const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]); \
+  template [[host_name("sdpa_vector_2pass_2_" #type "_" #head_dim)]] \
+  [[kernel]] void sdpa_vector_2pass_2<type, head_dim>( \
+      const device float* partials [[buffer(0)]], \
+      const device float* sums [[buffer(1)]], \
+      const device float* maxs [[buffer(2)]], \
+      device type* out [[buffer(3)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]); \
+
 #define instantiate_sdpa_vector_heads(type) \
   instantiate_sdpa_vector(type, 32) \
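The `host_name` attributes in these instantiations produce the strings the Rust dispatcher matches on (for example `sdpa_vector_2pass_1_float16_t_64`). A minimal sketch of that naming convention, using an illustrative helper that does not exist in the crate:

// Illustrative helper only: builds the host_name string emitted by the
// instantiation macros above for a given Metal element type and head dim.
fn sdpa_vector_2pass_kernel_name(pass: u8, metal_type: &str, head_dim: usize) -> String {
    // metal_type is "float16_t", "bfloat16_t" or "float", matching the #type token.
    format!("sdpa_vector_2pass_{pass}_{metal_type}_{head_dim}")
}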
@@ -1074,6 +1074,58 @@ impl candle::CustomOp3 for Sdpa {
 
         let command_buffer = q.device().command_buffer()?;
         if supports_sdpa_vector {
+            // Route to the 2 pass fused attention if the k seqlen is large.
+            // https://github.com/ml-explore/mlx/pull/1597
+            const TWO_PASS_K_THRESHOLD: usize = 1024;
+            if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
+                let mut intermediate_shape = [
+                    &out_dims[0..out_dims.len() - 2],
+                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
+                    &[out_dims[out_dims.len() - 1]],
+                ]
+                .concat();
+                let intermediate = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_intermediate",
+                )?;
+                let _ = intermediate_shape.pop().unwrap();
+                let sums = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_sums",
+                )?;
+                let maxs = device.new_buffer(
+                    intermediate_shape.iter().product::<usize>(),
+                    DType::F32,
+                    "sdpa_2pass_maxs",
+                )?;
+
+                command_buffer.set_label("vector_attention");
+                candle_metal_kernels::call_sdpa_vector_2pass(
+                    q.device().device(),
+                    &command_buffer,
+                    q.device().kernels(),
+                    q_l.start_offset(),
+                    q_l.dims(),
+                    q.buffer(),
+                    k_l.start_offset(),
+                    k_l.dims(),
+                    k_l.stride(),
+                    k.buffer(),
+                    v_l.start_offset(),
+                    v_l.stride(),
+                    v.buffer(),
+                    &output,
+                    &intermediate,
+                    &sums,
+                    &maxs,
+                    self.scale,
+                    self.softcapping,
+                    itype,
+                )
+                .map_err(candle::Error::wrap)?;
+            } else {
             command_buffer.set_label("vector_attention");
             candle_metal_kernels::call_sdpa_vector(
                 q.device().device(),
|
|||||||
itype,
|
itype,
|
||||||
)
|
)
|
||||||
.map_err(candle::Error::wrap)?;
|
.map_err(candle::Error::wrap)?;
|
||||||
|
}
|
||||||
} else if supports_sdpa_full {
|
} else if supports_sdpa_full {
|
||||||
if q_l.dim(2)? != k_l.dim(2)? {
|
if q_l.dim(2)? != k_l.dim(2)? {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
|