mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Improve the mnist training example. (#276)
* Improve the mnist training example. * Add some initialization routine that can be used for nn. * Proper initialization in the mnist example.
This commit is contained in:
@ -116,21 +116,48 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rand_uniform_f64(
|
||||
&self,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rand_uniform<T: crate::FloatDType>(
|
||||
&self,
|
||||
lo: T,
|
||||
up: T,
|
||||
shape: &Shape,
|
||||
) -> Result<Storage> {
|
||||
let lo = lo.to_f64();
|
||||
let up = up.to_f64();
|
||||
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
|
||||
}
|
||||
|
||||
pub(crate) fn rand_normal_f64(
|
||||
&self,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
|
||||
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
@ -142,18 +169,7 @@ impl Device {
|
||||
std: T,
|
||||
shape: &Shape,
|
||||
) -> Result<Storage> {
|
||||
let mean = mean.to_f64();
|
||||
let std = std.to_f64();
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
|
||||
}
|
||||
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
|
Reference in New Issue
Block a user