Enable BF16 on metal. (#2380)

This commit is contained in:
Laurent Mazare
2024-08-01 10:05:07 +01:00
committed by GitHub
parent ce90287f45
commit 957d604a78
2 changed files with 3 additions and 4 deletions

View File

@ -361,10 +361,8 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
&& device.is_cuda()
{
DType::BF16
if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
device.bf16_default_to_f32()
} else {
DType::F32
}