mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Further randn tweaks: use the appropriate rng rather than the f64 one, some cleanup. (#383)
This commit is contained in:
@ -2070,43 +2070,34 @@ impl BackendDevice for CpuDevice {
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
|
||||
.map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(bf16::from_f64(normal.sample(&mut rng)))
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::BF16(data))
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
|
||||
.map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(f16::from_f64(normal.sample(&mut rng)))
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F16(data))
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
let normal =
|
||||
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(normal.sample(&mut rng) as f32)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
|
||||
for _i in 0..elem_count {
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
|
Reference in New Issue
Block a user