Remove the old MFA gemm kernels. (#2742)

* Remove the old MFA gemm kernels.

* Use bf16 in helium on metal.
This commit is contained in:
Laurent Mazare
2025-01-26 20:36:31 +01:00
committed by GitHub
parent 1a32107fab
commit 27996a1a9e
6 changed files with 41 additions and 492 deletions

View File

@ -121,8 +121,6 @@ pub struct MetalDevice {
pub(crate) kernels: Arc<Kernels>,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones.
pub(crate) use_mlx_mm: bool,
}
impl std::fmt::Debug for MetalDevice {
@ -140,10 +138,6 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
self.use_mlx_mm = use_mlx_mm
}
pub fn compile(
&self,
func_name: &'static str,

View File

@ -1469,7 +1469,7 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
} else if self.device.use_mlx_mm {
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
@ -1496,32 +1496,6 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(
buffer,
@ -1884,10 +1858,6 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let kernels = Arc::new(Kernels::new());
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
Ok(_) => false,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
@ -1901,7 +1871,6 @@ impl BackendDevice for MetalDevice {
buffers: Arc::new(RwLock::new(HashMap::new())),
kernels,
seed,
use_mlx_mm,
})
}

View File

@ -263,11 +263,7 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)

View File

@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
);
(lhs, rhs)
};
let (dtype, name, sizeof) = if f32 {
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
let (dtype, sizeof) = if f32 {
(GemmDType::F32, core::mem::size_of::<f32>())
} else {
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
(GemmDType::F16, core::mem::size_of::<f16>())
};
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
for mlx in [false, true] {
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
if mlx {
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
} else {
candle_metal_kernels::call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
let mlx = if mlx { "MLX" } else { "MFA" };
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
println!("{dtype:?}, {n:6} gflops {gflops:.0}");
Ok(())
}

View File

@ -16,8 +16,6 @@ const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
@ -36,7 +34,6 @@ pub enum Source {
Fill,
Gemm,
Indexing,
Mfa,
Quantized,
Random,
Reduce,
@ -221,7 +218,6 @@ impl Kernels {
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -236,21 +232,11 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
let lib = {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
};
libraries.insert(source, lib.clone());
Ok(lib)
@ -1471,176 +1457,6 @@ impl ConstantValues {
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
false
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
false
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
let d_trans = false;
let alpha = 1.0f32;
let beta = 0.0f32;
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 8;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
} else {
let m_simd = 40;
let n_simd = 40;
let k_simd = 32;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
};
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),
(2, Value::USize(k)),
(10, Value::Bool(a_trans)),
(11, Value::Bool(b_trans)),
(13, Value::Bool(d_trans)),
(20, Value::F32(alpha)),
(21, Value::F32(beta)),
(100, Value::Bool(batched)),
(101, Value::Bool(fused_activation)),
// Garbage
(102, Value::Bool(false)),
(103, Value::Bool(false)),
(113, Value::Bool(false)),
(50_000, Value::Bool(false)),
// End garbage
(200, Value::U16(m_simd)),
(201, Value::U16(n_simd)),
(202, Value::U16(k_simd)),
(210, Value::U16(m_splits)),
(211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)),
]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits;
let a_block_length = m_group * k_simd;
let b_block_length = k_simd * n_group;
let mut block_elements = a_block_length + b_block_length;
if (m % 8 != 0) && (n % 8 != 0) {
let c_block_length = m_group * n_group;
block_elements = std::cmp::max(c_block_length, block_elements)
}
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
"bgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
)));
}
};
let block_bytes = block_elements * bytes;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
// TODO Tensor D
let grid_z = b;
if batched {
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;
let buffer: Vec<u64> = vec![
byte_stride_a as _,
byte_stride_b as _,
byte_stride_c as _,
byte_stride_d as _,
];
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
}
let grid_size = MTLSize {
width: divide(n, n_group.into()),
height: divide(m, m_group.into()),
depth: grid_z as NSUInteger,
};
let group_size = MTLSize {
width: 32 * (m_splits as u64) * (n_splits as u64),
height: 1,
depth: 1,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,

View File

@ -1046,168 +1046,6 @@ fn where_cond_u32_f32() {
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[allow(clippy::too_many_arguments)]
fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
lhs_stride,
lhs_offset,
&lhs,
rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
0,
);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
0,
);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_gemm(
"sgemm",
(1, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
12 * 4,
);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
// bgemm sanity test
if false {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_gemm(
"bgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let results = run_gemm(
"hgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
0,
);
assert_eq!(
approx_f16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
#[allow(clippy::too_many_arguments)]
fn run_mlx_gemm<T: Clone>(
dtype: GemmDType,
@ -1258,50 +1096,6 @@ fn run_mlx_gemm<T: Clone>(
read_to_vec(&output, length)
}
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
use rand::SeedableRng;
use rand_distr::Distribution;
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
let v1: Vec<f32> = run_mlx_gemm(
dtype,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
let v2: Vec<f32> = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
for (a, b) in v1.iter().zip(v2.iter()) {
let diff = (a - b).abs();
assert_eq!((diff * 1e4).round(), 0.)
}
}
#[test]
fn mlx_vs_mfa() {
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
}
#[test]
fn mlx_gemm() {
let (b, m, n, k) = (1, 2, 4, 3);