Random initializers. (#128)

* Random initialization.

* CPU rng generation.
This commit is contained in:
Laurent Mazare
2023-07-10 18:26:21 +01:00
committed by GitHub
parent e2807c78a4
commit f29b77ec19
6 changed files with 235 additions and 3 deletions

View File

@ -895,6 +895,66 @@ impl CpuStorage {
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
}
pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
}
DType::F32 => {
let mut data = Vec::new();
data.reserve(elem_count);
let uniform = rand::distributions::Uniform::new(0f32, 1f32);
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
}
Ok(Self::F32(data))
}
DType::F64 => {
let mut data = Vec::new();
data.reserve(elem_count);
let uniform = rand::distributions::Uniform::new(0f64, 1f64);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
}
Ok(Self::F64(data))
}
}
}
pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
}
DType::F32 => {
let mut data = Vec::new();
data.reserve(elem_count);
let std = std as f32;
let mean = mean as f32;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
}
Ok(Self::F32(data))
}
DType::F64 => {
let mut data = Vec::new();
data.reserve(elem_count);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
}
Ok(Self::F64(data))
}
}
}
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {