Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
This commit is contained in:
Laurent Mazare
2023-07-29 16:28:22 +01:00
committed by GitHub
parent bedcef64dc
commit 16c33383eb
6 changed files with 198 additions and 44 deletions

View File

@ -34,6 +34,33 @@ impl Var {
Ok(Self(inner))
}
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
}
pub fn rand_f64<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn randn_f64<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,