mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Re-revert the reverted revision (bf16 gemm metal)
This commit is contained in:
@ -173,8 +173,8 @@ impl Device {
|
|||||||
|
|
||||||
pub fn supports_bf16(&self) -> bool {
|
pub fn supports_bf16(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Self::Cuda(_) => true,
|
Self::Cuda(_) | Self::Metal(_) => true,
|
||||||
Self::Metal(_) | Self::Cpu => false,
|
Self::Cpu => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1162,7 +1162,6 @@ fn gemm() {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// bgemm sanity test
|
// bgemm sanity test
|
||||||
if false {
|
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
let lhs_stride = vec![m * k, k, 1];
|
let lhs_stride = vec![m * k, k, 1];
|
||||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||||
@ -1182,7 +1181,6 @@ fn gemm() {
|
|||||||
approx_bf16(results, 4),
|
approx_bf16(results, 4),
|
||||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
// hgemm sanity test
|
// hgemm sanity test
|
||||||
let (b, m, n, k) = (1, 2, 4, 3);
|
let (b, m, n, k) = (1, 2, 4, 3);
|
||||||
|
Reference in New Issue
Block a user