mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Remove the old MFA gemm kernels. (#2742)
* Remove the old MFA gemm kernels. * Use bf16 in helium on metal.
This commit is contained in:
@ -121,8 +121,6 @@ pub struct MetalDevice {
|
||||
pub(crate) kernels: Arc<Kernels>,
|
||||
/// Seed for random number generation.
|
||||
pub(crate) seed: Arc<Mutex<Buffer>>,
|
||||
/// Whether to use the MLX matmul kernels instead of the MFA ones.
|
||||
pub(crate) use_mlx_mm: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -140,10 +138,6 @@ impl std::ops::Deref for MetalDevice {
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
|
||||
self.use_mlx_mm = use_mlx_mm
|
||||
}
|
||||
|
||||
pub fn compile(
|
||||
&self,
|
||||
func_name: &'static str,
|
||||
|
@ -1469,7 +1469,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else if self.device.use_mlx_mm {
|
||||
} else {
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
@ -1496,32 +1496,6 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "sgemm",
|
||||
DType::F16 => "hgemm",
|
||||
dtype => {
|
||||
return Err(
|
||||
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
@ -1884,10 +1858,6 @@ impl BackendDevice for MetalDevice {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
let command_queue = device.new_command_queue();
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
|
||||
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
|
||||
Ok(_) => false,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
4,
|
||||
@ -1901,7 +1871,6 @@ impl BackendDevice for MetalDevice {
|
||||
buffers: Arc::new(RwLock::new(HashMap::new())),
|
||||
kernels,
|
||||
seed,
|
||||
use_mlx_mm,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -263,11 +263,7 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = {
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
(model, device)
|
||||
|
@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
|
||||
);
|
||||
(lhs, rhs)
|
||||
};
|
||||
let (dtype, name, sizeof) = if f32 {
|
||||
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
|
||||
let (dtype, sizeof) = if f32 {
|
||||
(GemmDType::F32, core::mem::size_of::<f32>())
|
||||
} else {
|
||||
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
|
||||
(GemmDType::F16, core::mem::size_of::<f16>())
|
||||
};
|
||||
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
|
||||
|
||||
for mlx in [false, true] {
|
||||
let mut sum_dt = 0f64;
|
||||
let mut iters = 0usize;
|
||||
for idx in 0.. {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start_time = std::time::Instant::now();
|
||||
if mlx {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
} else {
|
||||
candle_metal_kernels::call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
let dt = start_time.elapsed().as_secs_f64();
|
||||
if idx < WARMUP_ITERS {
|
||||
continue;
|
||||
}
|
||||
sum_dt += dt;
|
||||
iters += 1;
|
||||
if sum_dt > MIN_DUR {
|
||||
break;
|
||||
}
|
||||
let mut sum_dt = 0f64;
|
||||
let mut iters = 0usize;
|
||||
for idx in 0.. {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start_time = std::time::Instant::now();
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
let dt = start_time.elapsed().as_secs_f64();
|
||||
if idx < WARMUP_ITERS {
|
||||
continue;
|
||||
}
|
||||
sum_dt += dt;
|
||||
iters += 1;
|
||||
if sum_dt > MIN_DUR {
|
||||
break;
|
||||
}
|
||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||
let mlx = if mlx { "MLX" } else { "MFA" };
|
||||
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
|
||||
}
|
||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||
println!("{dtype:?}, {n:6} gflops {gflops:.0}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -16,8 +16,6 @@ const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const FILL: &str = include_str!("fill.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
@ -36,7 +34,6 @@ pub enum Source {
|
||||
Fill,
|
||||
Gemm,
|
||||
Indexing,
|
||||
Mfa,
|
||||
Quantized,
|
||||
Random,
|
||||
Reduce,
|
||||
@ -221,7 +218,6 @@ impl Kernels {
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Unary => UNARY,
|
||||
Source::Sdpa => SDPA,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,21 +232,11 @@ impl Kernels {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
let lib = {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
};
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
@ -1471,176 +1457,6 @@ impl ConstantValues {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_gemm(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
false
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
true
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
false
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
true
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let d_trans = false;
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
let batched = b > 1;
|
||||
let fused_activation = false;
|
||||
let fused_bias = false;
|
||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||
let m_simd = 8;
|
||||
let n_simd = 8;
|
||||
let k_simd = 64;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
} else {
|
||||
let m_simd = 40;
|
||||
let n_simd = 40;
|
||||
let k_simd = 32;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
};
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(0, Value::USize(m)),
|
||||
(1, Value::USize(n)),
|
||||
(2, Value::USize(k)),
|
||||
(10, Value::Bool(a_trans)),
|
||||
(11, Value::Bool(b_trans)),
|
||||
(13, Value::Bool(d_trans)),
|
||||
(20, Value::F32(alpha)),
|
||||
(21, Value::F32(beta)),
|
||||
(100, Value::Bool(batched)),
|
||||
(101, Value::Bool(fused_activation)),
|
||||
// Garbage
|
||||
(102, Value::Bool(false)),
|
||||
(103, Value::Bool(false)),
|
||||
(113, Value::Bool(false)),
|
||||
(50_000, Value::Bool(false)),
|
||||
// End garbage
|
||||
(200, Value::U16(m_simd)),
|
||||
(201, Value::U16(n_simd)),
|
||||
(202, Value::U16(k_simd)),
|
||||
(210, Value::U16(m_splits)),
|
||||
(211, Value::U16(n_splits)),
|
||||
(50_001, Value::Bool(fused_bias)),
|
||||
]));
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||
let m_group = m_simd * m_splits;
|
||||
let n_group = n_simd * n_splits;
|
||||
|
||||
let a_block_length = m_group * k_simd;
|
||||
let b_block_length = k_simd * n_group;
|
||||
|
||||
let mut block_elements = a_block_length + b_block_length;
|
||||
if (m % 8 != 0) && (n % 8 != 0) {
|
||||
let c_block_length = m_group * n_group;
|
||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||
}
|
||||
if fused_bias {
|
||||
if d_trans {
|
||||
block_elements = std::cmp::max(block_elements, m_group);
|
||||
} else {
|
||||
block_elements = std::cmp::max(block_elements, n_group);
|
||||
}
|
||||
}
|
||||
let bytes = match name {
|
||||
"sgemm" => 4,
|
||||
"hgemm" => 2,
|
||||
"bgemm" => 2,
|
||||
other => {
|
||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||
"{other} is not a valid kernel for gemm"
|
||||
)));
|
||||
}
|
||||
};
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(2, Some(output), 0);
|
||||
// TODO Tensor D
|
||||
|
||||
let grid_z = b;
|
||||
if batched {
|
||||
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_c = m * n * bytes as usize;
|
||||
// TODO byte_stride_d
|
||||
let byte_stride_d = 0;
|
||||
|
||||
let buffer: Vec<u64> = vec![
|
||||
byte_stride_a as _,
|
||||
byte_stride_b as _,
|
||||
byte_stride_c as _,
|
||||
byte_stride_d as _,
|
||||
];
|
||||
encoder.set_bytes(
|
||||
10,
|
||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||
);
|
||||
}
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: divide(n, n_group.into()),
|
||||
height: divide(m, m_group.into()),
|
||||
depth: grid_z as NSUInteger,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32 * (m_splits as u64) * (n_splits as u64),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum SdpaDType {
|
||||
BF16,
|
||||
|
@ -1046,168 +1046,6 @@ fn where_cond_u32_f32() {
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_gemm<T: Clone>(
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(rhs) as u64,
|
||||
options,
|
||||
);
|
||||
let length = b * m * n;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gemm() {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![
|
||||
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
||||
518.0, 548.0, 578.0
|
||||
]
|
||||
);
|
||||
|
||||
// OFFSET
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(1, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
12 * 4,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
|
||||
// bgemm sanity test
|
||||
if false {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"bgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
// hgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"hgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_f16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_mlx_gemm<T: Clone>(
|
||||
dtype: GemmDType,
|
||||
@ -1258,50 +1096,6 @@ fn run_mlx_gemm<T: Clone>(
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
|
||||
use rand::SeedableRng;
|
||||
use rand_distr::Distribution;
|
||||
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
|
||||
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
|
||||
|
||||
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let v1: Vec<f32> = run_mlx_gemm(
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
let v2: Vec<f32> = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
||||
let diff = (a - b).abs();
|
||||
assert_eq!((diff * 1e4).round(), 0.)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mlx_vs_mfa() {
|
||||
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
|
||||
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mlx_gemm() {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
|
Reference in New Issue
Block a user