Improve the mnist training example. (#276)

* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
This commit is contained in:
Laurent Mazare
2023-07-29 16:28:22 +01:00
committed by GitHub
parent bedcef64dc
commit 16c33383eb
6 changed files with 198 additions and 44 deletions

View File

@ -116,21 +116,48 @@ impl Device {
}
}
pub(crate) fn rand_uniform_f64(
&self,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
lo: T,
up: T,
shape: &Shape,
) -> Result<Storage> {
let lo = lo.to_f64();
let up = up.to_f64();
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
}
pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
@ -142,18 +169,7 @@ impl Device {
std: T,
shape: &Shape,
) -> Result<Storage> {
let mean = mean.to_f64();
let std = std.to_f64();
match self {
Device::Cpu => {
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {