mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn. * Also switch Tensor::rand to use a generic dtype. * Support sampling for f16. * Cleanup.
This commit is contained in:
@ -34,25 +34,23 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn rand<S: Into<Shape>>(
|
||||
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
|
||||
let inner = Tensor::rand_impl(lo, up, s, device, true)?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn randn<S: Into<Shape>>(
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
|
||||
let inner = Tensor::randn_impl(mean, std, s, device, true)?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user