Simplify Tensor::randn. (#255)

* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
This commit is contained in:
Laurent Mazare
2023-07-27 07:40:36 +01:00
committed by GitHub
parent 89ba005962
commit 6475bfadfe
10 changed files with 111 additions and 72 deletions

View File

@ -34,25 +34,23 @@ impl Var {
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>>(
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
s: S,
dtype: DType,
device: &Device,
lo: f64,
up: f64,
) -> Result<Self> {
let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
let inner = Tensor::rand_impl(lo, up, s, device, true)?;
Ok(Self(inner))
}
pub fn randn<S: Into<Shape>>(
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
let inner = Tensor::randn_impl(mean, std, s, device, true)?;
Ok(Self(inner))
}