mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add a hack for generating random uniform/normal for f16/bf16. (#1228)
This commit is contained in:
@ -185,11 +185,17 @@ impl Device {
|
|||||||
Ok(Storage::Cpu(storage))
|
Ok(Storage::Cpu(storage))
|
||||||
}
|
}
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
|
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||||
|
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||||
|
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
|
||||||
|
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||||
|
} else {
|
||||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn rand_uniform<T: crate::FloatDType>(
|
pub(crate) fn rand_uniform<T: crate::FloatDType>(
|
||||||
&self,
|
&self,
|
||||||
@ -213,11 +219,17 @@ impl Device {
|
|||||||
Ok(Storage::Cpu(storage))
|
Ok(Storage::Cpu(storage))
|
||||||
}
|
}
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
|
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||||
|
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||||
|
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
|
||||||
|
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||||
|
} else {
|
||||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn rand_normal<T: crate::FloatDType>(
|
pub(crate) fn rand_normal<T: crate::FloatDType>(
|
||||||
&self,
|
&self,
|
||||||
|
Reference in New Issue
Block a user