mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask * Dispatch to the 2pass kernel * Format
This commit is contained in:
@ -1906,7 +1906,12 @@ pub fn call_sdpa_vector(
|
||||
alpha
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
|
||||
let constants = Some(ConstantValues::new(vec![(
|
||||
20,
|
||||
Value::Bool(/* sdpa_vector_has_mask */ false),
|
||||
)]));
|
||||
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -1948,6 +1953,187 @@ pub fn call_sdpa_vector(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Depth of the first-pass dispatch grid launched by `call_sdpa_vector_2pass`.
///
/// NOTE(review): presumably the number of blocks the key/value sequence is split into, and
/// therefore the number of partial results per (batch, head) stored in the scratch buffers.
/// Confirm against the Metal kernel source before relying on this for buffer sizing.
pub const SDPA_2PASS_BLOCKS: usize = 32;
|
||||
|
||||
/// SDPA vector 2pass is supported when:
|
||||
/// - q head dim == 64, 96, 128
|
||||
/// - no mask
|
||||
/// - q,k,v are contiguous
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_sdpa_vector_2pass(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
q_offset: usize,
|
||||
q_shape: &[usize],
|
||||
q_buffer: &Buffer,
|
||||
k_offset: usize,
|
||||
k_shape: &[usize],
|
||||
k_stride: &[usize],
|
||||
k_buffer: &Buffer,
|
||||
v_offset: usize,
|
||||
v_stride: &[usize],
|
||||
v_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
intermediate: &Buffer,
|
||||
sums: &Buffer,
|
||||
maxs: &Buffer,
|
||||
alpha: f32,
|
||||
softcapping: f32,
|
||||
itype: SdpaDType,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let bk = q_shape.last().unwrap();
|
||||
|
||||
// First pass
|
||||
{
|
||||
let name_pass1 = match (bk, itype) {
|
||||
(32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
|
||||
(64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
|
||||
(96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
|
||||
(128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
|
||||
(256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
|
||||
(32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
|
||||
(64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
|
||||
(96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
|
||||
(128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
|
||||
(256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
|
||||
(32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
|
||||
(64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
|
||||
(96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
|
||||
(128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
|
||||
(256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
|
||||
(other, _) => {
|
||||
return Err(MetalKernelError::SdpaHeadSizeMismatch {
|
||||
variation: "vector_2pass_1",
|
||||
got: *other,
|
||||
expected: vec![32, 64, 96, 128, 256],
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
|
||||
let n = k_shape[2] as i32;
|
||||
let b = (q_shape[0] * q_shape[1]) as i32;
|
||||
let kstride = k_stride[1];
|
||||
let vstride = v_stride[1];
|
||||
|
||||
let alpha = if softcapping != 1. {
|
||||
alpha / softcapping
|
||||
} else {
|
||||
alpha
|
||||
};
|
||||
|
||||
let constants = Some(ConstantValues::new(vec![(
|
||||
20,
|
||||
Value::Bool(/* sdpa_vector_has_mask */ false),
|
||||
)]));
|
||||
|
||||
let pipeline =
|
||||
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
// q = (bs, qhead, seq, hidden)
|
||||
// k/v = (bs, kv_head, kv_seq, hidden)
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
(q_buffer, q_offset),
|
||||
(k_buffer, k_offset),
|
||||
(v_buffer, v_offset),
|
||||
intermediate,
|
||||
sums,
|
||||
maxs,
|
||||
gqa_factor,
|
||||
n,
|
||||
kstride,
|
||||
vstride,
|
||||
alpha,
|
||||
softcapping
|
||||
)
|
||||
);
|
||||
|
||||
let grid_dims = MTLSize {
|
||||
width: 1,
|
||||
height: b as u64,
|
||||
depth: SDPA_2PASS_BLOCKS as u64,
|
||||
};
|
||||
let group_dims = MTLSize {
|
||||
width: 8 * 32,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(sums, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Final pass
|
||||
{
|
||||
let name_pass2 = match (bk, itype) {
|
||||
(32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
|
||||
(64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
|
||||
(96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
|
||||
(128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
|
||||
(256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
|
||||
(32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
|
||||
(64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
|
||||
(96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
|
||||
(128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
|
||||
(256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
|
||||
(32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
|
||||
(64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
|
||||
(96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
|
||||
(128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
|
||||
(256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
|
||||
(other, _) => {
|
||||
return Err(MetalKernelError::SdpaHeadSizeMismatch {
|
||||
variation: "vector_2pass_2",
|
||||
got: *other,
|
||||
expected: vec![32, 64, 96, 128, 256],
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let b = (q_shape[0] * q_shape[1]) as i32;
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
// q = (bs, qhead, seq, hidden)
|
||||
// k/v = (bs, kv_head, kv_seq, hidden)
|
||||
|
||||
set_params!(encoder, (intermediate, sums, maxs, output));
|
||||
|
||||
let grid_dims = MTLSize {
|
||||
width: 1,
|
||||
height: b as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let group_dims = MTLSize {
|
||||
width: 1024,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(sums, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(grid_dims, group_dims);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col1d_strided(
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user