Add a hack for generating random uniform/normal for f16/bf16. (#1228)

This commit is contained in:
Laurent Mazare
2023-10-31 21:27:59 +01:00
committed by GitHub
parent c12ad45562
commit 36fb84f038

View File

@ -185,8 +185,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
}
@ -213,8 +219,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
}