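//! Training loop for a small Llama-style model on the pre-tokenized TinyStories
//! dataset, with periodic validation-loss reporting and checkpointing.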
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result};
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
use candle_nn::Optimizer;

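/// Estimates the validation loss as the average cross-entropy over a fixed
/// number of batches drawn at random from the validation split.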
fn valid_loss(
    dataset: &Dataset,
    model: &Llama,
    args: &crate::TrainingCmd,
    device: &Device,
) -> Result<f64> {
    // Draw random sequences from the validation split (second argument `true`).
    let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
    let mut sum_ce = 0f64;
    let mut cnt = 0usize;
    // Estimate the loss on 50 batches rather than the full split to keep it cheap.
    for inp_tgt in batch_iter.take(50) {
        let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        sum_ce += loss.to_vec0::<f32>()? as f64;
        cnt += 1;
    }
    Ok(sum_ce / cnt as f64)
}
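/// Entry point for training: loads the pre-tokenized dataset, builds a 15M
/// parameter model, and runs an AdamW training loop, printing the validation
/// loss every 100 batches and saving a checkpoint every 1000 batches.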
pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    let device = candle_examples::device(common_args.cpu)?;
    let dataset = Dataset::new(&args.pretokenized_dir)?;
    // Despite the names, `train_tokens()` / `valid_tokens()` report the number
    // of token files in each split, hence "files" in the message below.
    println!(
        "loaded dataset, train: {} files, valid: {} files",
        dataset.train_tokens(),
        dataset.valid_tokens()
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny_15m();
    // Draw random sequences from the training split (second argument `false`).
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

    let cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;
    // AdamW with the default betas/eps/weight decay; only the learning rate is
    // taken from the command line.
    let params = candle_nn::ParamsAdamW {
        lr: args.learning_rate,
        ..Default::default()
    };
    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
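    // Main training loop: one AdamW optimizer step per batch, interleaved with
    // periodic validation and checkpointing.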
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch?;
        let logits = model.forward(&inp, 0)?;
        // Flatten (batch, seq_len) so that logits are (N, vocab_size) and targets
        // are (N,), the shapes expected by cross_entropy.
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        opt.backward_step(&loss)?;

        if batch_index > 0 && batch_index % 100 == 0 {
            // TODO: Add a way to deactivate the backprop graph tracking when computing
            // the validation loss.
            let loss = valid_loss(&dataset, &model, args, &device)?;
            println!("{batch_index} {loss}");
        }
        if batch_index > 0 && batch_index % 1000 == 0 {
            varmap.save("checkpoint.safetensors")?;
        }
    }
    Ok(())
}
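
// A minimal sketch of how `run` could be invoked. The actual CLI wiring lives
// in the example's main.rs; the struct literals below are assumptions for
// illustration and only use the fields referenced in this file:
//
//     let args = crate::TrainingCmd {
//         pretokenized_dir: "tokenized-data".to_string(),
//         batch_size: 32,
//         learning_rate: 1e-4,
//     };
//     let common_args = crate::Args { cpu: true, ..Default::default() };
//     run(&args, &common_args)?;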