Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-16 18:48:51 +00:00
Add a flag to set the number of epochs in the mnist training (#283)
* Add a flag to change the number of epochs for the mnist training.
* Increase the learning rate for the MLP.
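For reference, the new flag would be exercised with something like `cargo run --example mnist-training -- mlp --epochs 100` (hedged: the example name and the `mlp`/`linear` model argument follow the repository's mnist-training example; only `--epochs` itself is introduced by this commit).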
```diff
@@ -163,11 +163,16 @@ impl Model for Mlp {
     }
 }
 
-fn training_loop<M: Model>(
-    m: candle_nn::vision::Dataset,
+struct TrainingArgs {
     learning_rate: f64,
     load: Option<String>,
     save: Option<String>,
+    epochs: usize,
+}
+
+fn training_loop<M: Model>(
+    m: candle_nn::vision::Dataset,
+    args: &TrainingArgs,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
```
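The shape of this hunk is the classic parameter-object refactor: the growing argument list of `training_loop` is folded into a `TrainingArgs` struct, so adding a knob such as `epochs` no longer touches the function signature or its call sites. A minimal standalone sketch of the pattern (hypothetical names, independent of candle):

```rust
// Parameter-object sketch: configuration travels as one borrowed struct.
struct TrainingArgs {
    learning_rate: f64,
    epochs: usize,
}

fn training_loop(args: &TrainingArgs) {
    // New fields can be added to TrainingArgs without changing this signature.
    for epoch in 1..args.epochs {
        println!("epoch {epoch}, lr {}", args.learning_rate);
    }
}

fn main() {
    let args = TrainingArgs { learning_rate: 0.05, epochs: 3 };
    training_loop(&args);
}
```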
```diff
@@ -181,17 +186,17 @@ fn training_loop<M: Model>(
     let mut vs = VarStore::new(DType::F32, dev.clone());
     let model = M::new(vs.clone())?;
 
-    if let Some(load) = load {
+    if let Some(load) = &args.load {
         println!("loading weights from {load}");
-        vs.load(&load)?
+        vs.load(load)?
     }
 
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
-    let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
+    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
     let test_images = m.test_images.to_device(&dev)?;
     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
-    for epoch in 1..200 {
+    for epoch in 1..args.epochs {
         let logits = model.forward(&train_images)?;
         let log_sm = ops::log_softmax(&logits, D::Minus1)?;
         let loss = loss::nll(&log_sm, &train_labels)?;
```
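One subtlety in the new loop bound: Rust ranges are half-open, so `1..args.epochs` performs `epochs - 1` passes, exactly as the old hard-coded `1..200` performed 199. A tiny standalone check:

```rust
fn main() {
    let epochs: usize = 5;
    // `1..epochs` is half-open: it yields 1, 2, 3, 4 (epochs - 1 iterations).
    let iterations = (1..epochs).count();
    assert_eq!(iterations, epochs - 1);
    println!("--epochs {epochs} runs {iterations} training passes");
}
```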
```diff
@@ -211,9 +216,9 @@ fn training_loop<M: Model>(
             100. * test_accuracy
         );
     }
-    if let Some(save) = save {
+    if let Some(save) = &args.save {
         println!("saving trained weights in {save}");
-        vs.save(&save)?
+        vs.save(save)?
     }
     Ok(())
 }
```
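The `load`/`save` handling also changes shape: matching on `&args.load` / `&args.save` borrows the `Option<String>` instead of consuming it, and the bound variable is then already a `&String`, which is why the explicit `&` disappears from `vs.load(load)?` and `vs.save(save)?`. A minimal sketch of that borrow pattern (hypothetical names, independent of candle):

```rust
struct Args {
    save: Option<String>,
}

// Takes &Args, so nothing can be moved out of the struct.
fn maybe_save(args: &Args) {
    // Matching on `&args.save` binds `save` as a &String: no clone, no move.
    if let Some(save) = &args.save {
        println!("saving trained weights in {save}");
    }
}

fn main() {
    let args = Args { save: Some("model.safetensors".to_string()) };
    maybe_save(&args);
    maybe_save(&args); // still usable: `args.save` was only borrowed
}
```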
```diff
@@ -232,6 +237,9 @@ struct Args {
     #[arg(long)]
     learning_rate: Option<f64>,
 
+    #[arg(long, default_value_t = 200)]
+    epochs: usize,
+
     /// The file where to save the trained weights, in safetensors format.
     #[arg(long)]
     save: Option<String>,
```
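With clap's derive macro (the example already uses `#[arg(...)]` attributes), the new field surfaces automatically as a `--epochs` flag defaulting to 200. A self-contained sketch of just that part, assuming clap 4 with the `derive` feature:

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Optional override for the learning rate.
    #[arg(long)]
    learning_rate: Option<f64>,

    /// Number of training epochs; defaults to 200 when the flag is omitted.
    #[arg(long, default_value_t = 200)]
    epochs: usize,
}

fn main() {
    let args = Args::parse();
    println!("epochs = {}, lr = {:?}", args.epochs, args.learning_rate);
}
```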
```diff
@@ -250,12 +258,18 @@ pub fn main() -> anyhow::Result<()> {
     println!("test-images: {:?}", m.test_images.shape());
     println!("test-labels: {:?}", m.test_labels.shape());
 
+    let default_learning_rate = match args.model {
+        WhichModel::Linear => 1.,
+        WhichModel::Mlp => 0.05,
+    };
+    let training_args = TrainingArgs {
+        epochs: args.epochs,
+        learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
+        load: args.load,
+        save: args.save,
+    };
     match args.model {
-        WhichModel::Linear => {
-            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
-        }
-        WhichModel::Mlp => {
-            training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
-        }
+        WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
+        WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
    }
}
```
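Net effect of this last hunk: the learning-rate default is now chosen per model before being folded into `TrainingArgs`, so running the MLP without flags trains at 0.05 (up from the previous 0.01) while the linear model keeps its default of 1.0, and an explicit `--learning-rate` value still takes precedence over either via `unwrap_or`.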