Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Improve the mnist training example. (#276)
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
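The "initialization routine that can be used for nn" points at building layer parameters from the new random constructors added below. The following is a hypothetical sketch of such a helper, not code from this commit: the `linear_init` name and the fan-in based uniform bound are illustrative assumptions, the crate path `candle_core` and the argument order of `Tensor::rand` after `lo` follow current candle releases.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical helper: uniform initialization for a linear layer, with the
// bound derived from the fan-in (a common convention, not this commit's code).
fn linear_init(in_dim: usize, out_dim: usize, device: &Device) -> Result<(Tensor, Tensor)> {
    let bound = 1f32 / (in_dim as f32).sqrt();
    let ws = Tensor::rand(-bound, bound, (out_dim, in_dim), device)?;
    let bs = Tensor::zeros(out_dim, DType::F32, device)?;
    Ok((ws, bs))
}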
@@ -245,6 +245,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn rand_f64_impl<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
     pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
         lo: T,
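The new `rand_f64_impl` is crate-internal; callers go through the public `Tensor::rand` visible in the context lines above. A minimal usage sketch, assuming the current four-argument form `rand(lo, up, shape, device)` (only `lo` is visible in this hunk) and the `candle_core` crate path:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 2x3 tensor with values drawn uniformly from the [-1.0, 1.0) range.
    let t = Tensor::rand(-1f32, 1f32, (2, 3), &device)?;
    println!("{:?}", t.to_vec2::<f32>()?);
    Ok(())
}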
@@ -268,6 +282,20 @@ impl Tensor {
         Ok(from_storage(storage, s, none, is_variable))
     }
 
+    pub(crate) fn randn_f64_impl<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
+    }
+
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
     pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
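Similarly, `randn_f64_impl` backs the public `Tensor::randn`. A hedged sketch of the kind of Gaussian weight initialization the mnist example can now use; the layer sizes and standard deviation are illustrative assumptions, not values taken from the example:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (in_dim, out_dim) = (784, 100);
    // Weights drawn from a normal distribution with mean 0 and a small std.
    let ws = Tensor::randn(0f32, 0.02f32, (out_dim, in_dim), &device)?;
    println!("{:?}", ws.shape());
    Ok(())
}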
@@ -1448,6 +1476,16 @@ impl Tensor {
         }
     }
 
+    /// Create a variable based on the values currently stored in a tensor. The storage is always
+    /// copied.
+    pub(crate) fn make_var(&self) -> Result<Tensor> {
+        let shape = self.shape().clone();
+        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), true))
+    }
+
     // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
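The `make_var` helper copies a tensor's storage and rebuilds it with the variable flag set, so backprop tracks gradients for the result. It is `pub(crate)`; in current candle the public route to it is `Var::from_tensor`, which is assumed in the sketch below and may postdate this commit:

use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let init = Tensor::randn(0f32, 0.02f32, (100, 784), &device)?;
    // Promote the initial values to a trainable variable; the storage is
    // copied and the result participates in gradient computation.
    let w = Var::from_tensor(&init)?;
    println!("{:?}", w.shape());
    Ok(())
}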