Random initializers. (#128)

* Random initialization.

* CPU rng generation.
This commit is contained in:
Laurent Mazare
2023-07-10 18:26:21 +01:00
committed by GitHub
parent e2807c78a4
commit f29b77ec19
6 changed files with 235 additions and 3 deletions

View File

@ -222,6 +222,58 @@ impl Tensor {
Tensor::zeros(self.shape(), self.dtype(), &self.device())
}
fn rand_uniform_impl<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform(&s, dtype)?;
Ok(from_storage(storage, s, None, is_variable))
}
pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
Self::rand_uniform_impl(s, dtype, device, false)
}
pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
Self::rand_uniform_impl(s, dtype, device, true)
}
fn rand_normal_impl<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal(&s, dtype, mean, std)?;
Ok(from_storage(storage, s, None, is_variable))
}
pub fn rand_normal<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
Self::rand_normal_impl(s, dtype, device, mean, std, false)
}
pub fn rand_normal_var<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
Self::rand_normal_impl(s, dtype, device, mean, std, true)
}
pub fn new_impl<A: crate::device::NdArray>(
array: A,
shape: Shape,