mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use BF16 on metal when possible. (#2378)
This commit is contained in:
@ -217,11 +217,7 @@ fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::v0_1_8x7b(args.use_flash_attn);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
Reference in New Issue
Block a user