Use AdamW in the llama2 training. (#308)

Author: Laurent Mazare
Date: 2023-08-02 14:14:02 +01:00
Committer: GitHub
Parent: 0902846f25
Commit: 4f17290ce0


@@ -150,12 +150,16 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
-    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+    let params = candle_nn::ParamsAdamW {
+        lr: args.learning_rate,
+        ..Default::default()
+    };
+    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
         let logits = model.forward(&inp, 0)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-        sgd.backward_step(&loss)?;
+        opt.backward_step(&loss)?;
         if batch_index > 0 && batch_index % 100 == 0 {
             // TODO: Add a way to deactivate the backprop graph tracking when computing the
@@ -163,6 +167,9 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
             let loss = valid_loss(&dataset, &model, args, &device)?;
             println!("{batch_index} {loss}");
         }
+        if batch_index > 0 && batch_index % 1000 == 0 {
+            varmap.save("checkpoint.safetensors")?
+        }
     }
     Ok(())
}
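
For reference, below is a minimal, self-contained sketch of the same optimizer pattern in isolation, assuming recent candle-core / candle-nn APIs (inside the candle repo the core crate is usually imported as `candle`). The linear model, the dummy data, and the 1e-3 learning rate are hypothetical stand-ins for the llama2 model, the token batches, and the CLI-provided learning rate used in the training loop above; only the learning rate is overridden while the remaining AdamW hyper-parameters come from `ParamsAdamW::default()`, matching the diff.

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{Optimizer, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // All trainable variables are registered in a VarMap so the optimizer can
    // update them, mirroring the varmap.all_vars() call in the diff above.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let model = candle_nn::linear(4, 1, vb.pp("lin"))?;

    // Same pattern as in the commit: keep the default AdamW hyper-parameters
    // and only override the learning rate (1e-3 is a hypothetical value; the
    // llama2 example reads it from the CLI arguments instead).
    let params = candle_nn::ParamsAdamW {
        lr: 1e-3,
        ..Default::default()
    };
    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;

    // Dummy regression batch standing in for the llama2 token batches.
    let inp = Tensor::rand(0f32, 1f32, (8, 4), &device)?;
    let tgt = Tensor::zeros((8, 1), DType::F32, &device)?;

    for step in 0..100 {
        let pred = model.forward(&inp)?;
        let loss = candle_nn::loss::mse(&pred, &tgt)?;
        // backward_step computes the gradients of the loss with respect to the
        // tracked variables and applies a single AdamW update. In recent
        // candle-nn this method comes from the Optimizer trait; at the time of
        // this commit it was an inherent method on AdamW.
        opt.backward_step(&loss)?;
        if step % 20 == 0 {
            println!("step {step}: loss {}", loss.to_scalar::<f32>()?);
        }
    }

    // Checkpointing works the same way as in the diff.
    varmap.save("checkpoint.safetensors")?;
    Ok(())
}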