Use BF16 on metal when possible. (#2378)

This commit is contained in:
Laurent Mazare
2024-08-01 09:48:58 +01:00
committed by GitHub
parent bd80078acf
commit 1ba87a9450
2 changed files with 17 additions and 5 deletions

View File

@ -171,6 +171,22 @@ impl Device {
matches!(self, Self::Metal(_))
}
pub fn supports_bf16(&self) -> bool {
match self {
Self::Cuda(_) | Self::Metal(_) => true,
Self::Cpu => false,
}
}
/// Return `BF16` for devices that support it, otherwise default to `F32`.
pub fn bf16_default_to_f32(&self) -> DType {
if self.supports_bf16() {
DType::BF16
} else {
DType::F32
}
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)