From e178aacead09b0f73a541b11b085e2b263448bbc Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:37:36 +0200 Subject: [PATCH] Re-revert the reverted revision (bf16 gemm metal) --- candle-core/src/device.rs | 4 ++-- candle-metal-kernels/src/tests.rs | 40 +++++++++++++++---------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 45f6554b..91e56937 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -173,8 +173,8 @@ impl Device { pub fn supports_bf16(&self) -> bool { match self { - Self::Cuda(_) => true, - Self::Metal(_) | Self::Cpu => false, + Self::Cuda(_) | Self::Metal(_) => true, + Self::Cpu => false, } } diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 30c454af..f70f773a 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1162,27 +1162,25 @@ fn gemm() { ); // bgemm sanity test - if false { - let (b, m, n, k) = (1, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); - let results = run_gemm( - "bgemm", - (b, m, n, k), - &lhs, - lhs_stride, - 0, - &rhs, - rhs_stride, - 0, - ); - assert_eq!( - approx_bf16(results, 4), - vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] - ); - } + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_gemm( + "bgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0,
80.0, 92.0] + ); // hgemm sanity test let (b, m, n, k) = (1, 2, 4, 3);