Add training for the llama2.c example (#296)

* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
Author: Laurent Mazare (committed by GitHub)
Date: 2023-08-01 17:23:07 +01:00
Parent: babee9f011
Commit: a27239f3d9
6 changed files with 227 additions and 9 deletions

candle-core/src/error.rs

@@ -228,3 +228,11 @@ macro_rules! bail {
return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
};
}
pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
match (r1, r2) {
(Ok(r1), Ok(r2)) => Ok((r1, r2)),
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e),
}
}
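The new `zip` helper combines two fallible results into one, short-circuiting on the first error. A minimal sketch of the pattern the training iterator below relies on (the function name is illustrative):

    use candle::{Device, Result, Tensor};

    // Build an (input, target) pair from a token window; either tensor
    // construction can fail, and zip propagates the first error.
    fn pair(tokens: &[u32], seq_len: usize, device: &Device) -> Result<(Tensor, Tensor)> {
        let inputs = Tensor::new(&tokens[..seq_len], device);
        let targets = Tensor::new(&tokens[1..], device);
        candle::error::zip(inputs, targets)
    }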

candle-core/src/lib.rs

@@ -44,7 +44,7 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod error;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "mkl")]

candle-examples/Cargo.toml

@@ -26,8 +26,9 @@ half = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true}
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }

candle-examples/examples/llama2-c/main.rs

@@ -4,6 +4,7 @@
extern crate intel_mkl_src;
mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};
@@ -64,19 +65,33 @@ struct EvaluationCmd {
which_model: String,
}
#[derive(Parser, Debug, Clone)]
pub struct TrainingCmd {
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
pretokenized_dir: String,
#[arg(long, default_value_t = 32)]
batch_size: usize,
#[arg(long, default_value_t = 0.001)]
learning_rate: f64,
}
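With clap's derive API these fields become kebab-case flags on the train subcommand. A hedged sketch of the resulting command line, using clap's parse_from (the binary name and dataset path are placeholders):

    use clap::Parser;

    fn parse_train_args() -> Args {
        Args::parse_from([
            "llama2-c", "train",
            "--pretokenized-dir", "/data/tinystories_pretok",
            "--batch-size", "32",
            "--learning-rate", "0.001",
        ])
    }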
#[derive(Subcommand, Debug, Clone)]
enum Task {
Inference(InferenceCmd),
Evaluation(EvaluationCmd),
Training,
Eval(EvaluationCmd),
Train(TrainingCmd),
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
pub struct Args {
/// The task to be performed: inference, training, or evaluation.
#[command(subcommand)]
task: Task,
task: Option<Task>,
/// Run on CPU rather than on GPU.
#[arg(long)]
@@ -104,9 +119,19 @@ impl Args {
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match &args.task {
Task::Inference(cmd) => run_inference(cmd, &args)?,
Task::Evaluation(cmd) => run_eval(cmd, &args)?,
Task::Training => todo!(),
None => {
let cmd = InferenceCmd {
temperature: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
which_model: "stories15M.bin".to_string(),
};
run_inference(&cmd, &args)?
}
Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
Some(Task::Train(cmd)) => training::run(cmd, &args)?,
}
Ok(())
}
@@ -202,6 +227,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;

candle-examples/examples/llama2-c/model.rs

@@ -15,6 +15,21 @@ pub struct Config {
pub norm_eps: f64,
}
impl Config {
pub fn tiny() -> Self {
Self {
dim: 288,
hidden_dim: 768,
n_layers: 6,
n_heads: 6,
n_kv_heads: 6,
vocab_size: 32000,
seq_len: 256,
norm_eps: 1e-5,
}
}
}
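This configuration matches the smallest llama2.c checkpoint, stories15M. A back-of-envelope check, assuming the standard Llama weight layout with tied embeddings and ignoring norms:

    // Rough parameter count for Config::tiny.
    fn approx_params(c: &Config) -> usize {
        let attn = 4 * c.dim * c.dim;       // wq, wk, wv, wo (n_kv_heads == n_heads)
        let ffn = 3 * c.dim * c.hidden_dim; // w1, w2, w3
        c.vocab_size * c.dim + c.n_layers * (attn + ffn)
    }
    // tiny(): 32000*288 + 6*(331_776 + 663_552) ≈ 15.2M parameters.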
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,

candle-examples/examples/llama2-c/training.rs (new file)

@@ -0,0 +1,168 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
}
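The directory is expected to hold the *.bin shards written by llama2.c's pre-tokenization (flat little-endian u16 token ids); after sorting, the first shard is held out for validation and the rest are used for training. A hedged usage sketch, with a placeholder path:

    // Load a pre-tokenized TinyStories directory produced by tinystories.py.
    fn load_dataset() -> Result<Dataset> {
        let ds = Dataset::new("data/TinyStories_all_data")?;
        println!("train shards: {}", ds.train_tokens.len());
        Ok(ds)
    }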
struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
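// Tokens are stored on disk as little-endian u16 values, so one token occupies two bytes.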
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}
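Each item is an (inputs, targets) pair shifted by one token; stacking items into batches is delegated to candle_nn's Batcher, as in the training loop below. A hedged sketch with placeholder batch size and sequence length:

    // Draw a single (batch_size, seq_len) training batch from the dataset.
    fn first_batch(ds: &Dataset, device: &Device) -> Result<(Tensor, Tensor)> {
        let iter = DatasetRandomIter::new(ds, /* valid = */ false, 256, device.clone());
        let mut batches = candle_nn::dataset::Batcher::new_r2(iter).batch_size(32);
        batches.next().expect("dataset yields at least one batch")
    }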
fn valid_loss(
dataset: &Dataset,
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
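// flatten_to(1) merges batch and time dims: logits (b, t, vocab) -> (b*t, vocab), targets (b, t) -> (b*t).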
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
}
Ok(sum_ce / cnt as f64)
}
pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
dataset.train_tokens.len(),
dataset.valid_tokens.len()
);
let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sgd.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}");
}
}
Ok(())
}
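Note that all_vars is empty for now, so the SGD step is a no-op until the TODO above is resolved. One way to close it, sketched under the assumption that candle_nn's VarMap API (VarMap::new, VarBuilder::from_varmap, VarMap::all_vars) is available, replacing the corresponding lines in run:

    // Track variables in a VarMap while building the model, then hand them to SGD.
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;
    let sgd = candle_nn::SGD::new(&varmap.all_vars(), args.learning_rate);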