Use BF16 on metal when possible. (#2378)

This commit is contained in:
Laurent Mazare
2024-08-01 09:48:58 +01:00
committed by GitHub
parent bd80078acf
commit 1ba87a9450
2 changed files with 17 additions and 5 deletions

View File

@ -217,11 +217,7 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::v0_1_8x7b(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());