diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index 92aa90e6..e55c686c 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -150,12 +150,16 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
-    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+    let params = candle_nn::ParamsAdamW {
+        lr: args.learning_rate,
+        ..Default::default()
+    };
+    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
         let logits = model.forward(&inp, 0)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-        sgd.backward_step(&loss)?;
+        opt.backward_step(&loss)?;
         if batch_index > 0 && batch_index % 100 == 0 {
             // TODO: Add a way to deactivate the backprop graph tracking when computing the
@@ -163,6 +167,9 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
             let loss = valid_loss(&dataset, &model, args, &device)?;
             println!("{batch_index} {loss}");
         }
+        if batch_index > 0 && batch_index % 1000 == 0 {
+            varmap.save("checkpoint.safetensors")?
+        }
     }
     Ok(())
 }
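Not part of the diff: a minimal sketch of how the checkpoint written by the new `varmap.save("checkpoint.safetensors")` call could be read back to resume training. `resume_from_checkpoint` is a hypothetical helper name, the `crate::model` import assumes this example's module layout, and `VarMap::load` is the candle-nn counterpart of the `save` call used above; the actual resume path may differ.

    use candle::{DType, Device, Result};
    use candle_nn::{VarBuilder, VarMap};

    use crate::model::{Cache, Config, Llama};

    /// Hypothetical resume helper: rebuild the model from a fresh VarMap, then
    /// overwrite its freshly initialized variables with the checkpointed values.
    fn resume_from_checkpoint(path: &str, config: Config, device: &Device) -> Result<(VarMap, Llama)> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        // Same construction as in `run` above so that all variables get registered.
        let cache = Cache::new(false, &config, vb.pp("rot"))?;
        let model = Llama::load(vb, &cache, config)?;
        // VarMap::load fills the existing variables in place from a safetensors file.
        varmap.load(path)?;
        Ok((varmap, model))
    }

The returned `varmap` can then be handed to `candle_nn::AdamW::new(varmap.all_vars(), params)?` exactly as in the patched loop, so optimization continues from the saved weights (optimizer state itself is not checkpointed here).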