Fix randn cpu (#382)

* Change distributions

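`Standard` samples uniformly from [0, 1), so scaling it by `std` and adding `mean` does not produce normally distributed values; `rand_distr::Normal` is the correct distribution to sample from.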

* Add test

Not sure if this is the best place to put the test.

* Remove unnecessary `use` statement
This commit is contained in:
Lei
2023-08-10 00:33:44 -04:00
committed by GitHub
parent 25ec2d9f6b
commit 3bbc08a8df
4 changed files with 40 additions and 10 deletions

View File

@@ -2070,35 +2070,45 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let std = bf16::from_f64(std);
let mean = bf16::from_f64(mean);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
for _i in 0..elem_count {
data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
data.push(bf16::from_f64(normal.sample(&mut rng)))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
let std = f16::from_f64(std);
let mean = f16::from_f64(mean);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
for _i in 0..elem_count {
data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
data.push(f16::from_f64(normal.sample(&mut rng)))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let std = std as f32;
let mean = mean as f32;
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng) as f32)
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
let normal = match rand_distr::Normal::new(mean, std) {
Ok(n) => n,
Err(e) => Err(Error::wrap(e))?,
};
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}