mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use AdamW in the llama2 training. (#308)
This commit is contained in:
@ -150,12 +150,16 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
|
||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
|
||||
let params = candle_nn::ParamsAdamW {
|
||||
lr: args.learning_rate,
|
||||
..Default::default()
|
||||
};
|
||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
||||
for (batch_index, batch) in batch_iter.enumerate() {
|
||||
let (inp, tgt) = batch?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
sgd.backward_step(&loss)?;
|
||||
opt.backward_step(&loss)?;
|
||||
|
||||
if batch_index > 0 && batch_index % 100 == 0 {
|
||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||
@ -163,6 +167,9 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let loss = valid_loss(&dataset, &model, args, &device)?;
|
||||
println!("{batch_index} {loss}");
|
||||
}
|
||||
if batch_index > 0 && batch_index % 1000 == 0 {
|
||||
varmap.save("checkpoint.safetensors")?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user