mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Remove the old MFA gemm kernels. (#2742)
* Remove the old MFA gemm kernels. * Use bf16 in helium on metal.
This commit is contained in:
@ -121,8 +121,6 @@ pub struct MetalDevice {
|
|||||||
pub(crate) kernels: Arc<Kernels>,
|
pub(crate) kernels: Arc<Kernels>,
|
||||||
/// Seed for random number generation.
|
/// Seed for random number generation.
|
||||||
pub(crate) seed: Arc<Mutex<Buffer>>,
|
pub(crate) seed: Arc<Mutex<Buffer>>,
|
||||||
/// Whether to use the MLX matmul kernels instead of the MFA ones.
|
|
||||||
pub(crate) use_mlx_mm: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for MetalDevice {
|
impl std::fmt::Debug for MetalDevice {
|
||||||
@ -140,10 +138,6 @@ impl std::ops::Deref for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalDevice {
|
impl MetalDevice {
|
||||||
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
|
|
||||||
self.use_mlx_mm = use_mlx_mm
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn compile(
|
pub fn compile(
|
||||||
&self,
|
&self,
|
||||||
func_name: &'static str,
|
func_name: &'static str,
|
||||||
|
@ -1469,7 +1469,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else if self.device.use_mlx_mm {
|
} else {
|
||||||
let dtype = match self.dtype {
|
let dtype = match self.dtype {
|
||||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||||
@ -1496,32 +1496,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
|
||||||
let name = match self.dtype {
|
|
||||||
DType::F32 => "sgemm",
|
|
||||||
DType::F16 => "hgemm",
|
|
||||||
dtype => {
|
|
||||||
return Err(
|
|
||||||
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
candle_metal_kernels::call_gemm(
|
|
||||||
&self.device.device,
|
|
||||||
&command_buffer,
|
|
||||||
&self.device.kernels,
|
|
||||||
name,
|
|
||||||
(b, m, n, k),
|
|
||||||
lhs_l.stride(),
|
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
|
||||||
&self.buffer,
|
|
||||||
rhs_l.stride(),
|
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
|
||||||
&rhs.buffer,
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
}
|
}
|
||||||
Ok(Self::new(
|
Ok(Self::new(
|
||||||
buffer,
|
buffer,
|
||||||
@ -1884,10 +1858,6 @@ impl BackendDevice for MetalDevice {
|
|||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
|
|
||||||
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
|
|
||||||
Ok(_) => false,
|
|
||||||
};
|
|
||||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||||
[299792458].as_ptr() as *const c_void,
|
[299792458].as_ptr() as *const c_void,
|
||||||
4,
|
4,
|
||||||
@ -1901,7 +1871,6 @@ impl BackendDevice for MetalDevice {
|
|||||||
buffers: Arc::new(RwLock::new(HashMap::new())),
|
buffers: Arc::new(RwLock::new(HashMap::new())),
|
||||||
kernels,
|
kernels,
|
||||||
seed,
|
seed,
|
||||||
use_mlx_mm,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,11 +263,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (model, device) = {
|
let (model, device) = {
|
||||||
let dtype = if device.is_cuda() {
|
let dtype = device.bf16_default_to_f32();
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
(model, device)
|
(model, device)
|
||||||
|
@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> {
|
|||||||
);
|
);
|
||||||
(lhs, rhs)
|
(lhs, rhs)
|
||||||
};
|
};
|
||||||
let (dtype, name, sizeof) = if f32 {
|
let (dtype, sizeof) = if f32 {
|
||||||
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
|
(GemmDType::F32, core::mem::size_of::<f32>())
|
||||||
} else {
|
} else {
|
||||||
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
|
(GemmDType::F16, core::mem::size_of::<f16>())
|
||||||
};
|
};
|
||||||
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
|
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
|
||||||
|
|
||||||
for mlx in [false, true] {
|
let mut sum_dt = 0f64;
|
||||||
let mut sum_dt = 0f64;
|
let mut iters = 0usize;
|
||||||
let mut iters = 0usize;
|
for idx in 0.. {
|
||||||
for idx in 0.. {
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let start_time = std::time::Instant::now();
|
||||||
let start_time = std::time::Instant::now();
|
candle_metal_kernels::call_mlx_gemm(
|
||||||
if mlx {
|
&device,
|
||||||
candle_metal_kernels::call_mlx_gemm(
|
command_buffer,
|
||||||
&device,
|
&kernels,
|
||||||
command_buffer,
|
dtype,
|
||||||
&kernels,
|
(b, m, n, k),
|
||||||
dtype,
|
&[m * k, k, 1],
|
||||||
(b, m, n, k),
|
0,
|
||||||
&[m * k, k, 1],
|
&lhs,
|
||||||
0,
|
&[n * k, n, 1],
|
||||||
&lhs,
|
0,
|
||||||
&[n * k, n, 1],
|
&rhs,
|
||||||
0,
|
&output,
|
||||||
&rhs,
|
)?;
|
||||||
&output,
|
command_buffer.commit();
|
||||||
)?;
|
command_buffer.wait_until_completed();
|
||||||
} else {
|
let dt = start_time.elapsed().as_secs_f64();
|
||||||
candle_metal_kernels::call_gemm(
|
if idx < WARMUP_ITERS {
|
||||||
&device,
|
continue;
|
||||||
command_buffer,
|
}
|
||||||
&kernels,
|
sum_dt += dt;
|
||||||
name,
|
iters += 1;
|
||||||
(b, m, n, k),
|
if sum_dt > MIN_DUR {
|
||||||
&[m * k, k, 1],
|
break;
|
||||||
0,
|
|
||||||
&lhs,
|
|
||||||
&[n * k, n, 1],
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&output,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
let dt = start_time.elapsed().as_secs_f64();
|
|
||||||
if idx < WARMUP_ITERS {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
sum_dt += dt;
|
|
||||||
iters += 1;
|
|
||||||
if sum_dt > MIN_DUR {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
|
||||||
let mlx = if mlx { "MLX" } else { "MFA" };
|
|
||||||
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
|
|
||||||
}
|
}
|
||||||
|
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||||
|
println!("{dtype:?}, {n:6} gflops {gflops:.0}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -16,8 +16,6 @@ const CAST: &str = include_str!("cast.metal");
|
|||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &str = include_str!("conv.metal");
|
||||||
const FILL: &str = include_str!("fill.metal");
|
const FILL: &str = include_str!("fill.metal");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
|
||||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
@ -36,7 +34,6 @@ pub enum Source {
|
|||||||
Fill,
|
Fill,
|
||||||
Gemm,
|
Gemm,
|
||||||
Indexing,
|
Indexing,
|
||||||
Mfa,
|
|
||||||
Quantized,
|
Quantized,
|
||||||
Random,
|
Random,
|
||||||
Reduce,
|
Reduce,
|
||||||
@ -221,7 +218,6 @@ impl Kernels {
|
|||||||
Source::Ternary => TERNARY,
|
Source::Ternary => TERNARY,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
Source::Sdpa => SDPA,
|
Source::Sdpa => SDPA,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,21 +232,11 @@ impl Kernels {
|
|||||||
if let Some(lib) = libraries.get(&source) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let lib = match source {
|
let lib = {
|
||||||
Source::Mfa => {
|
let source_content = self.get_library_source(source);
|
||||||
let source_data = MFA;
|
device
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
.new_library_with_source(source_content, &CompileOptions::new())
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
source => {
|
|
||||||
let source_content = self.get_library_source(source);
|
|
||||||
device
|
|
||||||
.new_library_with_source(source_content, &CompileOptions::new())
|
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
libraries.insert(source, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
@ -1471,176 +1457,6 @@ impl ConstantValues {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_gemm(
|
|
||||||
device: &Device,
|
|
||||||
ep: impl EncoderProvider,
|
|
||||||
kernels: &Kernels,
|
|
||||||
name: &'static str,
|
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
|
||||||
lhs_stride: &[usize],
|
|
||||||
lhs_offset: usize,
|
|
||||||
lhs_buffer: &Buffer,
|
|
||||||
rhs_stride: &[usize],
|
|
||||||
rhs_offset: usize,
|
|
||||||
rhs_buffer: &Buffer,
|
|
||||||
output: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
|
||||||
assert!(rhs_stride.len() >= 2);
|
|
||||||
assert!(lhs_stride.len() >= 2);
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
|
||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
|
||||||
// lhs has shape b, m, k
|
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
|
||||||
// there is a single element.
|
|
||||||
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
|
||||||
false
|
|
||||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
|
||||||
mnk: (m, n, k),
|
|
||||||
})?;
|
|
||||||
};
|
|
||||||
// rhs has shape b, k, n
|
|
||||||
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
|
||||||
false
|
|
||||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
|
||||||
mnk: (m, n, k),
|
|
||||||
})?;
|
|
||||||
};
|
|
||||||
let d_trans = false;
|
|
||||||
let alpha = 1.0f32;
|
|
||||||
let beta = 0.0f32;
|
|
||||||
let batched = b > 1;
|
|
||||||
let fused_activation = false;
|
|
||||||
let fused_bias = false;
|
|
||||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
|
||||||
let m_simd = 8;
|
|
||||||
let n_simd = 8;
|
|
||||||
let k_simd = 64;
|
|
||||||
let m_splits = 1;
|
|
||||||
let n_splits = 1;
|
|
||||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
|
||||||
} else {
|
|
||||||
let m_simd = 40;
|
|
||||||
let n_simd = 40;
|
|
||||||
let k_simd = 32;
|
|
||||||
let m_splits = 1;
|
|
||||||
let n_splits = 1;
|
|
||||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
|
||||||
};
|
|
||||||
let constants = Some(ConstantValues::new(vec![
|
|
||||||
(0, Value::USize(m)),
|
|
||||||
(1, Value::USize(n)),
|
|
||||||
(2, Value::USize(k)),
|
|
||||||
(10, Value::Bool(a_trans)),
|
|
||||||
(11, Value::Bool(b_trans)),
|
|
||||||
(13, Value::Bool(d_trans)),
|
|
||||||
(20, Value::F32(alpha)),
|
|
||||||
(21, Value::F32(beta)),
|
|
||||||
(100, Value::Bool(batched)),
|
|
||||||
(101, Value::Bool(fused_activation)),
|
|
||||||
// Garbage
|
|
||||||
(102, Value::Bool(false)),
|
|
||||||
(103, Value::Bool(false)),
|
|
||||||
(113, Value::Bool(false)),
|
|
||||||
(50_000, Value::Bool(false)),
|
|
||||||
// End garbage
|
|
||||||
(200, Value::U16(m_simd)),
|
|
||||||
(201, Value::U16(n_simd)),
|
|
||||||
(202, Value::U16(k_simd)),
|
|
||||||
(210, Value::U16(m_splits)),
|
|
||||||
(211, Value::U16(n_splits)),
|
|
||||||
(50_001, Value::Bool(fused_bias)),
|
|
||||||
]));
|
|
||||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
|
||||||
let m_group = m_simd * m_splits;
|
|
||||||
let n_group = n_simd * n_splits;
|
|
||||||
|
|
||||||
let a_block_length = m_group * k_simd;
|
|
||||||
let b_block_length = k_simd * n_group;
|
|
||||||
|
|
||||||
let mut block_elements = a_block_length + b_block_length;
|
|
||||||
if (m % 8 != 0) && (n % 8 != 0) {
|
|
||||||
let c_block_length = m_group * n_group;
|
|
||||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
|
||||||
}
|
|
||||||
if fused_bias {
|
|
||||||
if d_trans {
|
|
||||||
block_elements = std::cmp::max(block_elements, m_group);
|
|
||||||
} else {
|
|
||||||
block_elements = std::cmp::max(block_elements, n_group);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let bytes = match name {
|
|
||||||
"sgemm" => 4,
|
|
||||||
"hgemm" => 2,
|
|
||||||
"bgemm" => 2,
|
|
||||||
other => {
|
|
||||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
|
||||||
"{other} is not a valid kernel for gemm"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let block_bytes = block_elements * bytes;
|
|
||||||
|
|
||||||
let encoder = ep.encoder();
|
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
|
||||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
|
||||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
|
||||||
encoder.set_buffer(2, Some(output), 0);
|
|
||||||
// TODO Tensor D
|
|
||||||
|
|
||||||
let grid_z = b;
|
|
||||||
if batched {
|
|
||||||
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
|
||||||
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
|
||||||
let byte_stride_c = m * n * bytes as usize;
|
|
||||||
// TODO byte_stride_d
|
|
||||||
let byte_stride_d = 0;
|
|
||||||
|
|
||||||
let buffer: Vec<u64> = vec![
|
|
||||||
byte_stride_a as _,
|
|
||||||
byte_stride_b as _,
|
|
||||||
byte_stride_c as _,
|
|
||||||
byte_stride_d as _,
|
|
||||||
];
|
|
||||||
encoder.set_bytes(
|
|
||||||
10,
|
|
||||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
|
||||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
|
||||||
width: divide(n, n_group.into()),
|
|
||||||
height: divide(m, m_group.into()),
|
|
||||||
depth: grid_z as NSUInteger,
|
|
||||||
};
|
|
||||||
let group_size = MTLSize {
|
|
||||||
width: 32 * (m_splits as u64) * (n_splits as u64),
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
|
||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||||
pub enum SdpaDType {
|
pub enum SdpaDType {
|
||||||
BF16,
|
BF16,
|
||||||
|
@ -1046,168 +1046,6 @@ fn where_cond_u32_f32() {
|
|||||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn run_gemm<T: Clone>(
|
|
||||||
name: &'static str,
|
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
|
||||||
lhs: &[T],
|
|
||||||
lhs_stride: &[usize],
|
|
||||||
lhs_offset: usize,
|
|
||||||
rhs: &[T],
|
|
||||||
rhs_stride: &[usize],
|
|
||||||
rhs_offset: usize,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let device = device();
|
|
||||||
let kernels = Kernels::new();
|
|
||||||
let command_queue = device.new_command_queue();
|
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
|
||||||
|
|
||||||
let lhs = device.new_buffer_with_data(
|
|
||||||
lhs.as_ptr() as *const core::ffi::c_void,
|
|
||||||
std::mem::size_of_val(lhs) as u64,
|
|
||||||
options,
|
|
||||||
);
|
|
||||||
let rhs = device.new_buffer_with_data(
|
|
||||||
rhs.as_ptr() as *const core::ffi::c_void,
|
|
||||||
std::mem::size_of_val(rhs) as u64,
|
|
||||||
options,
|
|
||||||
);
|
|
||||||
let length = b * m * n;
|
|
||||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
|
||||||
call_gemm(
|
|
||||||
&device,
|
|
||||||
command_buffer,
|
|
||||||
&kernels,
|
|
||||||
name,
|
|
||||||
(b, m, n, k),
|
|
||||||
lhs_stride,
|
|
||||||
lhs_offset,
|
|
||||||
&lhs,
|
|
||||||
rhs_stride,
|
|
||||||
rhs_offset,
|
|
||||||
&rhs,
|
|
||||||
&output,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
read_to_vec(&output, length)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn gemm() {
|
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
|
||||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
|
||||||
let results = run_gemm(
|
|
||||||
"sgemm",
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&lhs_stride,
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&rhs_stride,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
|
||||||
);
|
|
||||||
|
|
||||||
let (b, m, n, k) = (2, 2, 4, 3);
|
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
|
||||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
|
||||||
let results = run_gemm(
|
|
||||||
"sgemm",
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&lhs_stride,
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&rhs_stride,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![
|
|
||||||
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
|
||||||
518.0, 548.0, 578.0
|
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
// OFFSET
|
|
||||||
let (b, m, n, k) = (2, 2, 4, 3);
|
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
|
||||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
|
||||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
|
||||||
let results = run_gemm(
|
|
||||||
"sgemm",
|
|
||||||
(1, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&lhs_stride,
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&rhs_stride,
|
|
||||||
12 * 4,
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx(results, 4),
|
|
||||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
|
||||||
);
|
|
||||||
|
|
||||||
// bgemm sanity test
|
|
||||||
if false {
|
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
|
||||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
|
||||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
|
||||||
let results = run_gemm(
|
|
||||||
"bgemm",
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&lhs_stride,
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&rhs_stride,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx_bf16(results, 4),
|
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// hgemm sanity test
|
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
|
||||||
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
|
||||||
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
|
||||||
let results = run_gemm(
|
|
||||||
"hgemm",
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&lhs_stride,
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&rhs_stride,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
approx_f16(results, 4),
|
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn run_mlx_gemm<T: Clone>(
|
fn run_mlx_gemm<T: Clone>(
|
||||||
dtype: GemmDType,
|
dtype: GemmDType,
|
||||||
@ -1258,50 +1096,6 @@ fn run_mlx_gemm<T: Clone>(
|
|||||||
read_to_vec(&output, length)
|
read_to_vec(&output, length)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
|
|
||||||
use rand::SeedableRng;
|
|
||||||
use rand_distr::Distribution;
|
|
||||||
|
|
||||||
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
|
|
||||||
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
|
|
||||||
|
|
||||||
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
|
|
||||||
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
|
|
||||||
let v1: Vec<f32> = run_mlx_gemm(
|
|
||||||
dtype,
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&[m * k, k, 1],
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&[k * n, n, 1],
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
let v2: Vec<f32> = run_gemm(
|
|
||||||
"sgemm",
|
|
||||||
(b, m, n, k),
|
|
||||||
&lhs,
|
|
||||||
&[m * k, k, 1],
|
|
||||||
0,
|
|
||||||
&rhs,
|
|
||||||
&[k * n, n, 1],
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
|
||||||
let diff = (a - b).abs();
|
|
||||||
assert_eq!((diff * 1e4).round(), 0.)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn mlx_vs_mfa() {
|
|
||||||
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
|
|
||||||
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
|
|
||||||
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
|
|
||||||
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
|
|
||||||
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn mlx_gemm() {
|
fn mlx_gemm() {
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
|
Reference in New Issue
Block a user