mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Metal bgemm min changes (#2364)
* Add updated mfa metallib * Add bgemm and tests
This commit is contained in:
@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
@ -1564,6 +1565,7 @@ pub fn call_gemm(
|
||||
let bytes = match name {
|
||||
"sgemm" => 4,
|
||||
"hgemm" => 2,
|
||||
"bgemm" => 2,
|
||||
other => {
|
||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||
"{other} is not a valid kernel for gemm"
|
||||
|
Binary file not shown.
@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() {
|
||||
}
|
||||
|
||||
fn run_gemm<T: Clone>(
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: Vec<usize>,
|
||||
@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"sgemm",
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_stride,
|
||||
lhs_offset,
|
||||
@ -1100,7 +1101,16 @@ fn gemm() {
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
@ -1111,7 +1121,16 @@ fn gemm() {
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![
|
||||
@ -1127,11 +1146,62 @@ fn gemm() {
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
||||
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
|
||||
let results = run_gemm(
|
||||
"sgemm",
|
||||
(1, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
12 * 4,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
|
||||
// bgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"bgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
|
||||
// hgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"hgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_f16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
|
Reference in New Issue
Block a user