mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Fix randn cpu (#382)
* Change distributions Standard generates in [0, 1), Normal is correct. * Add test Not sure if this is the best place to put the test * Remove unnecessary use
This commit is contained in:
@ -2070,35 +2070,45 @@ impl BackendDevice for CpuDevice {
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = bf16::from_f64(std);
|
||||
let mean = bf16::from_f64(mean);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(bf16::from_f64(normal.sample(&mut rng)))
|
||||
}
|
||||
Ok(CpuStorage::BF16(data))
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = f16::from_f64(std);
|
||||
let mean = f16::from_f64(mean);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(f16::from_f64(normal.sample(&mut rng)))
|
||||
}
|
||||
Ok(CpuStorage::F16(data))
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = std as f32;
|
||||
let mean = mean as f32;
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng) as f32)
|
||||
}
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let normal = match rand_distr::Normal::new(mean, std) {
|
||||
Ok(n) => n,
|
||||
Err(e) => Err(Error::wrap(e))?,
|
||||
};
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
|
||||
data.push(normal.sample(&mut rng))
|
||||
}
|
||||
Ok(CpuStorage::F64(data))
|
||||
}
|
||||
|
Reference in New Issue
Block a user