mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn. * Also switch Tensor::rand to use a generic dtype. * Support sampling for f16. * Cleanup.
This commit is contained in:
@ -232,55 +232,51 @@ impl Tensor {
|
||||
Tensor::zeros(self.shape(), self.dtype(), self.device())
|
||||
}
|
||||
|
||||
pub(crate) fn rand_impl<S: Into<Shape>>(
|
||||
pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_uniform(&s, dtype, lo, up)?;
|
||||
let storage = device.rand_uniform(lo, up, &s)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
|
||||
pub fn rand<S: Into<Shape>>(
|
||||
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_impl(s, dtype, device, lo, up, false)
|
||||
Self::rand_impl(lo, up, s, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn randn_impl<S: Into<Shape>>(
|
||||
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_normal(&s, dtype, mean, std)?;
|
||||
let storage = device.rand_normal(mean, std, &s)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
/// specified `mean` and standard deviation `std`.
|
||||
pub fn randn<S: Into<Shape>>(
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::randn_impl(s, dtype, device, mean, std, false)
|
||||
Self::randn_impl(mean, std, s, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn new_impl<A: crate::device::NdArray>(
|
||||
|
Reference in New Issue
Block a user