mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a flag to set the number of epochs in the mnist training (#283)
* Add a flag to change the number of epochs for the mnist training. * Increase the learning rate for the MLP.
This commit is contained in:
@ -163,11 +163,16 @@ impl Model for Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_nn::vision::Dataset,
|
||||
struct TrainingArgs {
|
||||
learning_rate: f64,
|
||||
load: Option<String>,
|
||||
save: Option<String>,
|
||||
epochs: usize,
|
||||
}
|
||||
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_nn::vision::Dataset,
|
||||
args: &TrainingArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = candle::Device::cuda_if_available(0)?;
|
||||
|
||||
@ -181,17 +186,17 @@ fn training_loop<M: Model>(
|
||||
let mut vs = VarStore::new(DType::F32, dev.clone());
|
||||
let model = M::new(vs.clone())?;
|
||||
|
||||
if let Some(load) = load {
|
||||
if let Some(load) = &args.load {
|
||||
println!("loading weights from {load}");
|
||||
vs.load(&load)?
|
||||
vs.load(load)?
|
||||
}
|
||||
|
||||
let all_vars = vs.all_vars();
|
||||
let all_vars = all_vars.iter().collect::<Vec<_>>();
|
||||
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
|
||||
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
for epoch in 1..200 {
|
||||
for epoch in 1..args.epochs {
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
@ -211,9 +216,9 @@ fn training_loop<M: Model>(
|
||||
100. * test_accuracy
|
||||
);
|
||||
}
|
||||
if let Some(save) = save {
|
||||
if let Some(save) = &args.save {
|
||||
println!("saving trained weights in {save}");
|
||||
vs.save(&save)?
|
||||
vs.save(save)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -232,6 +237,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
learning_rate: Option<f64>,
|
||||
|
||||
#[arg(long, default_value_t = 200)]
|
||||
epochs: usize,
|
||||
|
||||
/// The file where to save the trained weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
save: Option<String>,
|
||||
@ -250,12 +258,18 @@ pub fn main() -> anyhow::Result<()> {
|
||||
println!("test-images: {:?}", m.test_images.shape());
|
||||
println!("test-labels: {:?}", m.test_labels.shape());
|
||||
|
||||
let default_learning_rate = match args.model {
|
||||
WhichModel::Linear => 1.,
|
||||
WhichModel::Mlp => 0.05,
|
||||
};
|
||||
let training_args = TrainingArgs {
|
||||
epochs: args.epochs,
|
||||
learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
|
||||
load: args.load,
|
||||
save: args.save,
|
||||
};
|
||||
match args.model {
|
||||
WhichModel::Linear => {
|
||||
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
|
||||
}
|
||||
WhichModel::Mlp => {
|
||||
training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
|
||||
}
|
||||
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
|
||||
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user