mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Enable BF16 on metal. (#2380)
This commit is contained in:
@ -361,10 +361,8 @@ fn main() -> Result<()> {
|
||||
let dtype = match args.dtype {
|
||||
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||
None => {
|
||||
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
|
||||
&& device.is_cuda()
|
||||
{
|
||||
DType::BF16
|
||||
if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
|
||||
device.bf16_default_to_f32()
|
||||
} else {
|
||||
DType::F32
|
||||
}
|
||||
|
Reference in New Issue
Block a user