Add a flag to set the number of epochs in the mnist training (#283)

* Add a flag to change the number of epochs for the mnist training.

* Increase the learning rate for the MLP.
This commit is contained in:
Laurent Mazare
2023-07-31 10:32:14 +01:00
committed by GitHub
parent 67834119fc
commit 62a9b03715

View File

@ -163,11 +163,16 @@ impl Model for Mlp {
}
}
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
struct TrainingArgs {
learning_rate: f64,
load: Option<String>,
save: Option<String>,
epochs: usize,
}
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
@ -181,17 +186,17 @@ fn training_loop<M: Model>(
let mut vs = VarStore::new(DType::F32, dev.clone());
let model = M::new(vs.clone())?;
if let Some(load) = load {
if let Some(load) = &args.load {
println!("loading weights from {load}");
vs.load(&load)?
vs.load(load)?
}
let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..200 {
for epoch in 1..args.epochs {
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
@ -211,9 +216,9 @@ fn training_loop<M: Model>(
100. * test_accuracy
);
}
if let Some(save) = save {
if let Some(save) = &args.save {
println!("saving trained weights in {save}");
vs.save(&save)?
vs.save(save)?
}
Ok(())
}
@ -232,6 +237,9 @@ struct Args {
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long, default_value_t = 200)]
epochs: usize,
/// The file where to save the trained weights, in safetensors format.
#[arg(long)]
save: Option<String>,
@ -250,12 +258,18 @@ pub fn main() -> anyhow::Result<()> {
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let default_learning_rate = match args.model {
WhichModel::Linear => 1.,
WhichModel::Mlp => 0.05,
};
let training_args = TrainingArgs {
epochs: args.epochs,
learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
load: args.load,
save: args.save,
};
match args.model {
WhichModel::Linear => {
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
}
WhichModel::Mlp => {
training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
}
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
}
}