mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Random initializers. (#128)
* Random initialization. * CPU rng generation.
This commit is contained in:
@ -222,6 +222,58 @@ impl Tensor {
|
||||
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
|
||||
fn rand_uniform_impl<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_uniform(&s, dtype)?;
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
}
|
||||
|
||||
pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
Self::rand_uniform_impl(s, dtype, device, false)
|
||||
}
|
||||
|
||||
pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
Self::rand_uniform_impl(s, dtype, device, true)
|
||||
}
|
||||
|
||||
fn rand_normal_impl<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_normal(&s, dtype, mean, std)?;
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
}
|
||||
|
||||
pub fn rand_normal<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_normal_impl(s, dtype, device, mean, std, false)
|
||||
}
|
||||
|
||||
pub fn rand_normal_var<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_normal_impl(s, dtype, device, mean, std, true)
|
||||
}
|
||||
|
||||
pub fn new_impl<A: crate::device::NdArray>(
|
||||
array: A,
|
||||
shape: Shape,
|
||||
|
Reference in New Issue
Block a user