mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Re-revert the reverted revision (bf16 gemm metal)
This commit is contained in:
@ -173,8 +173,8 @@ impl Device {
|
|||||||
|
|
||||||
pub fn supports_bf16(&self) -> bool {
|
pub fn supports_bf16(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Self::Cuda(_) => true,
|
Self::Cuda(_) | Self::Metal(_) => true,
|
||||||
Self::Metal(_) | Self::Cpu => false,
|
Self::Cpu => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1162,27 +1162,25 @@ fn gemm() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// bgemm sanity test
|
// bgemm sanity test
|
||||||
if false {
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
let lhs_stride = vec![m * k, k, 1];
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
let rhs_stride = vec![n * k, n, 1];
|
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
let results = run_gemm(
|
||||||
let results = run_gemm(
|
"bgemm",
|
||||||
"bgemm",
|
(b, m, n, k),
|
||||||
(b, m, n, k),
|
&lhs,
|
||||||
&lhs,
|
lhs_stride,
|
||||||
lhs_stride,
|
0,
|
||||||
0,
|
&rhs,
|
||||||
&rhs,
|
rhs_stride,
|
||||||
rhs_stride,
|
0,
|
||||||
0,
|
);
|
||||||
);
|
assert_eq!(
|
||||||
assert_eq!(
|
approx_bf16(results, 4),
|
||||||
approx_bf16(results, 4),
|
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// hgemm sanity test
|
// hgemm sanity test
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
|
Reference in New Issue
Block a user