Further randn tweaks: use the appropriate rng rather than the f64 one, some cleanup. (#383)

This commit is contained in:
Laurent Mazare
2023-08-10 06:48:19 +02:00
committed by GitHub
parent 3bbc08a8df
commit c7f92f985e
2 changed files with 21 additions and 37 deletions

View File

@ -2070,43 +2070,34 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(bf16::from_f64(normal.sample(&mut rng)))
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(f16::from_f64(normal.sample(&mut rng)))
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
let normal =
rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng) as f32)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}