Remove the old MFA gemm kernels. (#2742)

* Remove the old MFA gemm kernels.

* Use bf16 in helium on metal.
This commit is contained in:
Laurent Mazare
2025-01-26 20:36:31 +01:00
committed by GitHub
parent 1a32107fab
commit 27996a1a9e
6 changed files with 41 additions and 492 deletions

View File

@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
);
(lhs, rhs)
};
let (dtype, name, sizeof) = if f32 {
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
let (dtype, sizeof) = if f32 {
(GemmDType::F32, core::mem::size_of::<f32>())
} else {
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
(GemmDType::F16, core::mem::size_of::<f16>())
};
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
for mlx in [false, true] {
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
if mlx {
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
} else {
candle_metal_kernels::call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
let mlx = if mlx { "MLX" } else { "MFA" };
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
println!("{dtype:?}, {n:6} gflops {gflops:.0}");
Ok(())
}