diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index d50cb944..937510c7 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -163,11 +163,16 @@ impl Model for Mlp {
     }
 }
 
-fn training_loop<M: Model>(
-    m: candle_nn::vision::Dataset,
+struct TrainingArgs {
     learning_rate: f64,
     load: Option<String>,
     save: Option<String>,
+    epochs: usize,
+}
+
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    args: &TrainingArgs,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
@@ -181,17 +186,17 @@ fn training_loop<M: Model>(
     let mut vs = VarStore::new(DType::F32, dev.clone());
     let model = M::new(vs.clone())?;
 
-    if let Some(load) = load {
+    if let Some(load) = &args.load {
         println!("loading weights from {load}");
-        vs.load(&load)?
+        vs.load(load)?
     }
 
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
+    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
     let test_images = m.test_images.to_device(&dev)?;
     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
-    for epoch in 1..200 {
+    for epoch in 1..args.epochs {
         let logits = model.forward(&train_images)?;
         let log_sm = ops::log_softmax(&logits, D::Minus1)?;
         let loss = loss::nll(&log_sm, &train_labels)?;
@@ -211,9 +216,9 @@
             100. * test_accuracy
         );
     }
-    if let Some(save) = save {
+    if let Some(save) = &args.save {
         println!("saving trained weights in {save}");
-        vs.save(&save)?
+        vs.save(save)?
     }
     Ok(())
 }
@@ -232,6 +237,9 @@ struct Args {
     #[arg(long)]
     learning_rate: Option<f64>,
 
+    #[arg(long, default_value_t = 200)]
+    epochs: usize,
+
     /// The file where to save the trained weights, in safetensors format.
     #[arg(long)]
     save: Option<String>,
@@ -250,12 +258,18 @@ pub fn main() -> anyhow::Result<()> {
     println!("test-images: {:?}", m.test_images.shape());
     println!("test-labels: {:?}", m.test_labels.shape());
 
+    let default_learning_rate = match args.model {
+        WhichModel::Linear => 1.,
+        WhichModel::Mlp => 0.05,
+    };
+    let training_args = TrainingArgs {
+        epochs: args.epochs,
+        learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
+        load: args.load,
+        save: args.save,
+    };
     match args.model {
-        WhichModel::Linear => {
-            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
-        }
-        WhichModel::Mlp => {
-            training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
-        }
+        WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
+        WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
     }
 }
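
A quick usage sketch for the new flag (a hypothetical invocation: `--epochs` and its default of 200 come from the hunks above, and `--learning-rate`/`--save` mirror the existing clap fields; model selection is left to whatever default the `WhichModel` argument carries, which is outside these hunks):

    cargo run --example mnist-training -- --epochs 50 --learning-rate 0.05 --save mnist.safetensors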