Use AdamW in the llama2 training. (#308)
@@ -150,12 +150,16 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
-    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
+    let params = candle_nn::ParamsAdamW {
+        lr: args.learning_rate,
+        ..Default::default()
+    };
+    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
         let logits = model.forward(&inp, 0)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-        sgd.backward_step(&loss)?;
+        opt.backward_step(&loss)?;

         if batch_index > 0 && batch_index % 100 == 0 {
             // TODO: Add a way to deactivate the backprop graph tracking when computing the
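For reference, here is a standalone sketch of what the new optimizer setup amounts to, with the main `ParamsAdamW` fields spelled out instead of hidden behind `..Default::default()`. The field names, hyper-parameter values, and the toy linear model are assumptions about the candle-nn API around the time of this change, not part of the commit itself.

```rust
// Minimal sketch: the AdamW construction from the hunk above, made
// self-contained with a toy linear model.
use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Module, ParamsAdamW, VarBuilder, VarMap};
// In recent candle-nn releases `backward_step` comes from the Optimizer
// trait; at the time of this commit it was an inherent method on AdamW.
use candle_nn::Optimizer;

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // A tiny model, just to have some trainable variables in the varmap.
    let model = candle_nn::linear(4, 1, vb.pp("lin"))?;

    // Same shape as the diff, with the main fields written out explicitly
    // (the values here are illustrative assumptions).
    let params = ParamsAdamW {
        lr: 3e-4,
        beta1: 0.9,
        beta2: 0.999,
        weight_decay: 0.01,
        ..Default::default()
    };
    let mut opt = AdamW::new(varmap.all_vars(), params)?;

    // One dummy optimization step: mean squared error on random data.
    let xs = Tensor::randn(0f32, 1f32, (8, 4), &device)?;
    let ys = Tensor::zeros((8, 1), DType::F32, &device)?;
    let loss = (model.forward(&xs)? - ys)?.sqr()?.mean_all()?;
    opt.backward_step(&loss)?;
    Ok(())
}
```

Because both `SGD` and `AdamW` drive the step through the same `backward_step(&loss)` call, the training loop itself is untouched; only the construction site and the variable name change in the diff.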
@@ -163,6 +167,9 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
             let loss = valid_loss(&dataset, &model, args, &device)?;
             println!("{batch_index} {loss}");
         }
+        if batch_index > 0 && batch_index % 1000 == 0 {
+            varmap.save("checkpoint.safetensors")?
+        }
     }
     Ok(())
 }
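The second hunk adds periodic checkpointing with `varmap.save("checkpoint.safetensors")?`. Below is a hedged sketch of the matching restore path; `VarMap::load`, the helper name `restore_if_present`, and the toy model are assumptions about candle-nn's `VarMap` API rather than code from this commit.

```rust
// Sketch of resuming from the checkpoint written by the hunk above.
// VarMap::load copies tensors from a safetensors file into variables that
// already exist in the map, so the model must be built (registering its
// variables) before loading. File name and helper are illustrative only.
use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};

fn restore_if_present(varmap: &mut VarMap, path: &str) -> candle_core::Result<()> {
    if std::path::Path::new(path).exists() {
        varmap.load(path)?; // overwrites the matching vars in place
    }
    Ok(())
}

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let mut varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Build the model first so its variables are registered in the varmap...
    let _model = candle_nn::linear(4, 1, vb.pp("lin"))?;
    // ...then pull in the checkpoint, if the training loop has saved one.
    restore_if_present(&mut varmap, "checkpoint.safetensors")?;
    Ok(())
}
```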