mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Metal bgemm min changes (#2364)
* Add updated mfa metallib * Add bgemm and tests
This commit is contained in:
@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal");
|
|||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &str = include_str!("conv.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
|
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
const SORT: &str = include_str!("sort.metal");
|
const SORT: &str = include_str!("sort.metal");
|
||||||
@ -1564,6 +1565,7 @@ pub fn call_gemm(
|
|||||||
let bytes = match name {
|
let bytes = match name {
|
||||||
"sgemm" => 4,
|
"sgemm" => 4,
|
||||||
"hgemm" => 2,
|
"hgemm" => 2,
|
||||||
|
"bgemm" => 2,
|
||||||
other => {
|
other => {
|
||||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
"{other} is not a valid kernel for gemm"
|
"{other} is not a valid kernel for gemm"
|
||||||
|
Binary file not shown.
@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run_gemm<T: Clone>(
|
fn run_gemm<T: Clone>(
|
||||||
|
name: &'static str,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
lhs: &[T],
|
lhs: &[T],
|
||||||
lhs_stride: Vec<usize>,
|
lhs_stride: Vec<usize>,
|
||||||
@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>(
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
"sgemm",
|
name,
|
||||||
(b, m, n, k),
|
(b, m, n, k),
|
||||||
&lhs_stride,
|
&lhs_stride,
|
||||||
lhs_offset,
|
lhs_offset,
|
||||||
@ -1100,7 +1101,16 @@ fn gemm() {
|
|||||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
let results = run_gemm(
|
||||||
|
"sgemm",
|
||||||
|
(b, m, n, k),
|
||||||
|
&lhs,
|
||||||
|
lhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
rhs_stride,
|
||||||
|
0,
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx(results, 4),
|
approx(results, 4),
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||||
@ -1111,7 +1121,16 @@ fn gemm() {
|
|||||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
let results = run_gemm(
|
||||||
|
"sgemm",
|
||||||
|
(b, m, n, k),
|
||||||
|
&lhs,
|
||||||
|
lhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
rhs_stride,
|
||||||
|
0,
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx(results, 4),
|
approx(results, 4),
|
||||||
vec![
|
vec![
|
||||||
@ -1127,11 +1146,62 @@ fn gemm() {
|
|||||||
let rhs_stride = vec![n * k, n, 1];
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
||||||
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
|
let results = run_gemm(
|
||||||
|
"sgemm",
|
||||||
|
(1, m, n, k),
|
||||||
|
&lhs,
|
||||||
|
lhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
rhs_stride,
|
||||||
|
12 * 4,
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
approx(results, 4),
|
approx(results, 4),
|
||||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// bgemm sanity test
|
||||||
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
|
let lhs_stride = vec![m * k, k, 1];
|
||||||
|
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||||
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
|
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||||
|
let results = run_gemm(
|
||||||
|
"bgemm",
|
||||||
|
(b, m, n, k),
|
||||||
|
&lhs,
|
||||||
|
lhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
rhs_stride,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
approx_bf16(results, 4),
|
||||||
|
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||||
|
);
|
||||||
|
|
||||||
|
// hgemm sanity test
|
||||||
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
|
let lhs_stride = vec![m * k, k, 1];
|
||||||
|
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||||
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
|
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||||
|
let results = run_gemm(
|
||||||
|
"hgemm",
|
||||||
|
(b, m, n, k),
|
||||||
|
&lhs,
|
||||||
|
lhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
rhs_stride,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
approx_f16(results, 4),
|
||||||
|
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||||
|
Reference in New Issue
Block a user