Remove the old MFA gemm kernels. (#2742)

* Remove the old MFA gemm kernels.

* Use bf16 in helium on metal.
This commit is contained in:
Laurent Mazare
2025-01-26 20:36:31 +01:00
committed by GitHub
parent 1a32107fab
commit 27996a1a9e
6 changed files with 41 additions and 492 deletions

View File

@ -16,8 +16,6 @@ const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
@ -36,7 +34,6 @@ pub enum Source {
Fill,
Gemm,
Indexing,
Mfa,
Quantized,
Random,
Reduce,
@ -221,7 +218,6 @@ impl Kernels {
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -236,21 +232,11 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
let lib = {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
};
libraries.insert(source, lib.clone());
Ok(lib)
@ -1471,176 +1457,6 @@ impl ConstantValues {
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
false
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
false
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
let d_trans = false;
let alpha = 1.0f32;
let beta = 0.0f32;
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 8;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
} else {
let m_simd = 40;
let n_simd = 40;
let k_simd = 32;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
};
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),
(2, Value::USize(k)),
(10, Value::Bool(a_trans)),
(11, Value::Bool(b_trans)),
(13, Value::Bool(d_trans)),
(20, Value::F32(alpha)),
(21, Value::F32(beta)),
(100, Value::Bool(batched)),
(101, Value::Bool(fused_activation)),
// Garbage
(102, Value::Bool(false)),
(103, Value::Bool(false)),
(113, Value::Bool(false)),
(50_000, Value::Bool(false)),
// End garbage
(200, Value::U16(m_simd)),
(201, Value::U16(n_simd)),
(202, Value::U16(k_simd)),
(210, Value::U16(m_splits)),
(211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)),
]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits;
let a_block_length = m_group * k_simd;
let b_block_length = k_simd * n_group;
let mut block_elements = a_block_length + b_block_length;
if (m % 8 != 0) && (n % 8 != 0) {
let c_block_length = m_group * n_group;
block_elements = std::cmp::max(c_block_length, block_elements)
}
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
"bgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
)));
}
};
let block_bytes = block_elements * bytes;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
// TODO Tensor D
let grid_z = b;
if batched {
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;
let buffer: Vec<u64> = vec![
byte_stride_a as _,
byte_stride_b as _,
byte_stride_c as _,
byte_stride_d as _,
];
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
}
let grid_size = MTLSize {
width: divide(n, n_group.into()),
height: divide(m, m_group.into()),
depth: grid_z as NSUInteger,
};
let group_size = MTLSize {
width: 32 * (m_splits as u64) * (n_splits as u64),
height: 1,
depth: 1,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,