From 1ba87a94505ee52bdd362fd4da11d1ee945ea593 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 1 Aug 2024 09:48:58 +0100 Subject: [PATCH] Use BF16 on metal when possible. (#2378) --- candle-core/src/device.rs | 16 ++++++++++++++++ candle-examples/examples/mixtral/main.rs | 6 +----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 1cd26167..91e56937 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -171,6 +171,22 @@ impl Device { matches!(self, Self::Metal(_)) } + pub fn supports_bf16(&self) -> bool { + match self { + Self::Cuda(_) | Self::Metal(_) => true, + Self::Cpu => false, + } + } + + /// Return `BF16` for devices that support it, otherwise default to `F32`. + pub fn bf16_default_to_f32(&self) -> DType { + if self.supports_bf16() { + DType::BF16 + } else { + DType::F32 + } + } + pub fn cuda_if_available(ordinal: usize) -> Result { if crate::utils::cuda_is_available() { Self::new_cuda(ordinal) diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs index fe47e537..8d4cc3fb 100644 --- a/candle-examples/examples/mixtral/main.rs +++ b/candle-examples/examples/mixtral/main.rs @@ -217,11 +217,7 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config = Config::v0_1_8x7b(args.use_flash_attn); let device = candle_examples::device(args.cpu)?; - let dtype = if device.is_cuda() { - DType::BF16 - } else { - DType::F32 - }; + let dtype = device.bf16_default_to_f32(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let model = Model::new(&config, vb)?; println!("loaded the model in {:?}", start.elapsed());