From a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 1 Aug 2023 17:23:07 +0100
Subject: [PATCH] Add training for the llama2.c example (#296)

* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
---
 candle-core/src/error.rs                      |   8 +
 candle-core/src/lib.rs                        |   2 +-
 candle-examples/Cargo.toml                    |   3 +-
 candle-examples/examples/llama2-c/main.rs     |  40 ++++-
 candle-examples/examples/llama2-c/model.rs    |  15 ++
 candle-examples/examples/llama2-c/training.rs | 168 ++++++++++++++++++
 6 files changed, 227 insertions(+), 9 deletions(-)
 create mode 100644 candle-examples/examples/llama2-c/training.rs

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 30d06239..35a33032 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -228,3 +228,11 @@ macro_rules! bail {
         return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
     };
 }
+
+pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
+    match (r1, r2) {
+        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
+        (Err(e), _) => Err(e),
+        (_, Err(e)) => Err(e),
+    }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 95cc189c..52244052 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -44,7 +44,7 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
-mod error;
+pub mod error;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "mkl")]
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index ff28c646..0ec67942 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -26,8 +26,9 @@ half = { workspace = true, optional = true }
 [dev-dependencies]
 anyhow = { workspace = true }
 byteorder = { workspace = true }
-hf-hub = { workspace = true}
 clap = { workspace = true }
+hf-hub = { workspace = true }
+memmap2 = { workspace = true }
 rand = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index c02c65b9..8b64fdd2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -4,6 +4,7 @@ extern crate intel_mkl_src;
 mod model;
+mod training;
 mod weights;
 
 use clap::{Parser, Subcommand};
@@ -64,19 +65,33 @@ struct EvaluationCmd {
     which_model: String,
 }
 
+#[derive(Parser, Debug, Clone)]
+pub struct TrainingCmd {
+    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+    /// script from llama2.c https://github.com/karpathy/llama2.c
+    #[arg(long)]
+    pretokenized_dir: String,
+
+    #[arg(long, default_value_t = 32)]
+    batch_size: usize,
+
+    #[arg(long, default_value_t = 0.001)]
+    learning_rate: f64,
+}
+
 #[derive(Subcommand, Debug, Clone)]
 enum Task {
     Inference(InferenceCmd),
-    Evaluation(EvaluationCmd),
-    Training,
+    Eval(EvaluationCmd),
+    Train(TrainingCmd),
 }
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
-struct Args {
+pub struct Args {
     /// The task to be performed, inference, training or evaluation.
     #[command(subcommand)]
-    task: Task,
+    task: Option<Task>,
 
     /// Run on CPU rather than on GPU.
     #[arg(long)]
@@ -104,9 +119,19 @@ impl Args {
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match &args.task {
-        Task::Inference(cmd) => run_inference(cmd, &args)?,
-        Task::Evaluation(cmd) => run_eval(cmd, &args)?,
-        Task::Training => todo!(),
+        None => {
+            let cmd = InferenceCmd {
+                temperature: None,
+                prompt: "".to_string(),
+                config: None,
+                model_id: "karpathy/tinyllamas".to_string(),
+                which_model: "stories15M.bin".to_string(),
+            };
+            run_inference(&cmd, &args)?
+        }
+        Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
+        Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
+        Some(Task::Train(cmd)) => training::run(cmd, &args)?,
     }
     Ok(())
 }
@@ -202,6 +227,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
 
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
+    println!("{config:?}");
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
     let vb = weights.var_builder(&config, &device)?;
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 618bf67c..4e7015dd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -15,6 +15,21 @@ pub struct Config {
     pub norm_eps: f64,
 }
 
+impl Config {
+    pub fn tiny() -> Self {
+        Self {
+            dim: 288,
+            hidden_dim: 768,
+            n_layers: 6,
+            n_heads: 6,
+            n_kv_heads: 6,
+            vocab_size: 32000,
+            seq_len: 256,
+            norm_eps: 1e-5,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
new file mode 100644
index 00000000..196ba9a8
--- /dev/null
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -0,0 +1,168 @@
+#![allow(dead_code)]
+#![allow(unused)]
+use crate::model::{Cache, Config, Llama};
+use candle::{DType, Device, Result, Tensor};
+
+pub struct Dataset {
+    valid_tokens: Vec<memmap2::Mmap>,
+    train_tokens: Vec<memmap2::Mmap>,
+}
+
+fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
+    let file = std::fs::File::open(p)?;
+    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+    Ok(mmap)
+}
+
+impl Dataset {
+    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
+        let dir = dir.as_ref();
+        let mut bin_files = vec![];
+        for file in std::fs::read_dir(dir)?.flatten() {
+            let file = file.path();
+            if let Some(extension) = file.extension() {
+                if extension == "bin" {
+                    bin_files.push(file)
+                }
+            }
+        }
+        if bin_files.len() < 2 {
+            candle::bail!("found less than two bin files in {:?}", dir)
+        }
+        bin_files.sort();
+        let valid_tokens = mmap_file(&bin_files[0])?;
+        let train_tokens = bin_files[1..]
+            .iter()
+            .map(mmap_file)
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self {
+            valid_tokens: vec![valid_tokens],
+            train_tokens,
+        })
+    }
+}
+
+struct DatasetRandomIter<'a> {
+    all_tokens: &'a [memmap2::Mmap],
+    tokens: Vec<&'a memmap2::Mmap>,
+    current_tokens: &'a memmap2::Mmap,
+    indexes_in_bytes: Vec<usize>,
+    seq_len: usize,
+    device: Device,
+}
+
+impl<'a> DatasetRandomIter<'a> {
+    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
+        use rand::seq::SliceRandom;
+        use rand::thread_rng;
+
+        let all_tokens = if valid {
+            &ds.valid_tokens
+        } else {
+            &ds.train_tokens
+        };
+        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
+        tokens.shuffle(&mut thread_rng());
+        let current_tokens = tokens.pop().unwrap();
+        let seq_len_in_bytes = seq_len * 2;
+        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
+            .step_by(seq_len_in_bytes)
+            .collect::<Vec<_>>();
+        indexes_in_bytes.shuffle(&mut thread_rng());
+        Self {
+            all_tokens,
+            tokens,
+            current_tokens,
+            indexes_in_bytes,
+            seq_len,
+            device,
+        }
+    }
+}
+
+impl<'a> Iterator for DatasetRandomIter<'a> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        use byteorder::{LittleEndian, ReadBytesExt};
+        use rand::seq::SliceRandom;
+        use rand::thread_rng;
+
+        let seq_len = self.seq_len;
+        if self.indexes_in_bytes.is_empty() {
+            if self.tokens.is_empty() {
+                self.tokens = self.all_tokens.iter().collect();
+                self.tokens.shuffle(&mut thread_rng());
+            }
+            self.current_tokens = self.tokens.pop().unwrap();
+            let seq_len_in_bytes = self.seq_len * 2;
+            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
+                .step_by(seq_len_in_bytes)
+                .collect::<Vec<_>>();
+            self.indexes_in_bytes.shuffle(&mut thread_rng());
+        }
+        let start_idx = self.indexes_in_bytes.pop().unwrap();
+        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
+        let mut tokens = vec![0u16; bytes.len() / 2];
+        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
+            return Some(Err(err.into()));
+        }
+        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
+        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
+        let targets = Tensor::new(&tokens[1..], &self.device);
+        Some(candle::error::zip(inputs, targets))
+    }
+}
+
+fn valid_loss(
+    dataset: &Dataset,
+    model: &Llama,
+    args: &crate::TrainingCmd,
+    device: &Device,
+) -> Result<f64> {
+    let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
+    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let mut sum_ce = 0f64;
+    let mut cnt = 0usize;
+    for inp_tgt in batch_iter.take(50) {
+        let (inp, tgt) = inp_tgt?;
+        let logits = model.forward(&inp, 0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+        sum_ce += loss.to_vec0::<f32>()? as f64;
+        cnt += 1;
+    }
+    Ok(sum_ce / cnt as f64)
+}
+
+pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
+    let device = candle_examples::device(common_args.cpu)?;
+    let dataset = Dataset::new(&args.pretokenized_dir)?;
+    println!(
+        "loaded dataset, train: {} files, valid: {} files",
+        dataset.train_tokens.len(),
+        dataset.valid_tokens.len()
+    );
+    let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
+    let config = Config::tiny();
+    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
+    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+
+    let cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+    let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
+    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
+    for (batch_index, batch) in batch_iter.enumerate() {
+        let (inp, tgt) = batch?;
+        let logits = model.forward(&inp, 0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+        sgd.backward_step(&loss)?;
+
+        if batch_index > 0 && batch_index % 100 == 0 {
+            // TODO: Add a way to deactivate the backprop graph tracking when computing the
+            // validation loss.
+            let loss = valid_loss(&dataset, &model, args, &device)?;
+            println!("{batch_index} {loss}");
+        }
+    }
+    Ok(())
+}
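
Usage sketch (not part of the patch itself): with a directory of *.bin token files produced by the tinystories.py script from llama2.c, the new `train` subcommand can be invoked as below. The dataset path is a placeholder, and the flag names are the kebab-case forms clap derives from the `TrainingCmd` fields:

    # --pretokenized-dir points at the tinystories.py output; the path below is a placeholder
    cargo run --example llama2-c --release -- train \
        --pretokenized-dir /path/to/pretokenized \
        --batch-size 32 \
        --learning-rate 0.001

Running the example with no subcommand now defaults to inference with stories15M.bin from the karpathy/tinyllamas repo.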