mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Remove the old MFA gemm kernels. (#2742)
* Remove the old MFA gemm kernels. * Use bf16 in helium on metal.
This commit is contained in:
@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
|
||||
);
|
||||
(lhs, rhs)
|
||||
};
|
||||
let (dtype, name, sizeof) = if f32 {
|
||||
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
|
||||
let (dtype, sizeof) = if f32 {
|
||||
(GemmDType::F32, core::mem::size_of::<f32>())
|
||||
} else {
|
||||
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
|
||||
(GemmDType::F16, core::mem::size_of::<f16>())
|
||||
};
|
||||
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
|
||||
|
||||
for mlx in [false, true] {
|
||||
let mut sum_dt = 0f64;
|
||||
let mut iters = 0usize;
|
||||
for idx in 0.. {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start_time = std::time::Instant::now();
|
||||
if mlx {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
} else {
|
||||
candle_metal_kernels::call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
let dt = start_time.elapsed().as_secs_f64();
|
||||
if idx < WARMUP_ITERS {
|
||||
continue;
|
||||
}
|
||||
sum_dt += dt;
|
||||
iters += 1;
|
||||
if sum_dt > MIN_DUR {
|
||||
break;
|
||||
}
|
||||
let mut sum_dt = 0f64;
|
||||
let mut iters = 0usize;
|
||||
for idx in 0.. {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start_time = std::time::Instant::now();
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
let dt = start_time.elapsed().as_secs_f64();
|
||||
if idx < WARMUP_ITERS {
|
||||
continue;
|
||||
}
|
||||
sum_dt += dt;
|
||||
iters += 1;
|
||||
if sum_dt > MIN_DUR {
|
||||
break;
|
||||
}
|
||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||
let mlx = if mlx { "MLX" } else { "MFA" };
|
||||
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
|
||||
}
|
||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||
println!("{dtype:?}, {n:6} gflops {gflops:.0}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user