Simplify Tensor::randn. (#255)

* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
This commit is contained in:
Laurent Mazare
2023-07-27 07:40:36 +01:00
committed by GitHub
parent 89ba005962
commit 6475bfadfe
10 changed files with 111 additions and 72 deletions

View File

@ -232,55 +232,51 @@ impl Tensor {
Tensor::zeros(self.shape(), self.dtype(), self.device())
}
pub(crate) fn rand_impl<S: Into<Shape>>(
pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
s: S,
dtype: DType,
device: &Device,
lo: f64,
up: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform(&s, dtype, lo, up)?;
let storage = device.rand_uniform(lo, up, &s)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>>(
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
s: S,
dtype: DType,
device: &Device,
lo: f64,
up: f64,
) -> Result<Self> {
Self::rand_impl(s, dtype, device, lo, up, false)
Self::rand_impl(lo, up, s, device, false)
}
pub(crate) fn randn_impl<S: Into<Shape>>(
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal(&s, dtype, mean, std)?;
let storage = device.rand_normal(mean, std, &s)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>>(
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
Self::randn_impl(s, dtype, device, mean, std, false)
Self::randn_impl(mean, std, s, device, false)
}
pub(crate) fn new_impl<A: crate::device::NdArray>(