diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 162e3f2b..661bdd2a 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -2,8 +2,8 @@ mod benchmarks; use criterion::criterion_main; criterion_main!( - benchmarks::affine::benches, + //benchmarks::affine::benches, benchmarks::matmul::benches, - benchmarks::random::benches, - benchmarks::where_cond::benches + //benchmarks::random::benches, + //benchmarks::where_cond::benches ); diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs index 9d67e642..fa19ecfa 100644 --- a/candle-core/benches/benchmarks/matmul.rs +++ b/candle-core/benches/benchmarks/matmul.rs @@ -13,11 +13,11 @@ fn run_bench(c: &mut Criterion, device: &Device) { let n = 2048; let k = 2048; - let dtype = DType::F32; + let dtype = DType::BF16; let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap(); let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap(); - let flops = b * m * n * k; + let flops = b * m * n * k * dtype.size_in_bytes(); let mut group = c.benchmark_group(device.bench_name("matmul")); group.throughput(Throughput::Bytes(flops as u64)); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6e1ecc5e..abd647af 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1254,6 +1254,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "sgemm", DType::F16 => "hgemm", + DType::BF16 => "bgemm", dtype => { return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 33bc3453..f76af4cb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1340,6 +1340,7 @@ pub fn call_gemm( let bytes = match name { "sgemm" => 4, "hgemm" => 2, + "bgemm" => 2, other => { return Err(MetalKernelError::LoadLibraryError(format!( "{other} is not a valid kernel for gemm" diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib index 1e2d1acf..57634be8 100644 Binary files a/candle-metal-kernels/src/libMetalFlashAttention.metallib and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 459c8edb..0da8619c 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -857,7 +857,20 @@ fn where_cond() { assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } -fn run_gemm( +trait Gemmable: Clone { + const gemm_name: &'static str; +} +impl Gemmable for f32 { + const gemm_name: &'static str = "sgemm"; +} +impl Gemmable for f16 { + const gemm_name: &'static str = "hgemm"; +} +impl Gemmable for bf16 { + const gemm_name: &'static str = "bgemm"; +} + +fn run_gemm( (b, m, n, k): (usize, usize, usize, usize), lhs: &[T], lhs_stride: Vec, @@ -866,6 +879,7 @@ fn run_gemm( rhs_stride: Vec, rhs_offset: usize, ) -> Vec { + let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -888,7 +902,7 @@ fn run_gemm( &device, command_buffer, &kernels, - "sgemm", + T::gemm_name, (b, m, n, k), &lhs_stride, lhs_offset, @@ -909,23 +923,23 @@ fn run_gemm( fn gemm() { let (b, m, n, k) = (1, 2, 4, 3); let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); assert_eq!( - approx(results, 4), + approx_bf16(results, 4), vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] ); let (b, m, n, k) = (2, 2, 4, 3); let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); assert_eq!( - approx(results, 4), + approx_bf16(results, 4), vec![ 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, 518.0, 548.0, 578.0 @@ -935,13 +949,13 @@ fn gemm() { // OFFSET let (b, m, n, k) = (2, 2, 4, 3); let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); assert_eq!( - approx(results, 4), + approx_bf16(results, 4), vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] ); }