mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default. * Add the training module and load the training dataset. * Random dataset iterator. * Proper valid-loss computation. * Compute the evaluation loss. * Add more substance to the training loop.
This commit is contained in:
@ -26,8 +26,9 @@ half = { workspace = true, optional = true }
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true}
|
||||
clap = { workspace = true }
|
||||
hf-hub = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
|
@ -4,6 +4,7 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
mod model;
|
||||
mod training;
|
||||
mod weights;
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
@ -64,19 +65,33 @@ struct EvaluationCmd {
|
||||
which_model: String,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
pub struct TrainingCmd {
|
||||
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
|
||||
/// script from llama2.c https://github.com/karpathy/llama2.c
|
||||
#[arg(long)]
|
||||
pretokenized_dir: String,
|
||||
|
||||
#[arg(long, default_value_t = 32)]
|
||||
batch_size: usize,
|
||||
|
||||
#[arg(long, default_value_t = 0.001)]
|
||||
learning_rate: f64,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Task {
|
||||
Inference(InferenceCmd),
|
||||
Evaluation(EvaluationCmd),
|
||||
Training,
|
||||
Eval(EvaluationCmd),
|
||||
Train(TrainingCmd),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
pub struct Args {
|
||||
/// The task to be performed, inference, training or evaluation.
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
task: Option<Task>,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
@ -104,9 +119,19 @@ impl Args {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
match &args.task {
|
||||
Task::Inference(cmd) => run_inference(cmd, &args)?,
|
||||
Task::Evaluation(cmd) => run_eval(cmd, &args)?,
|
||||
Task::Training => todo!(),
|
||||
None => {
|
||||
let cmd = InferenceCmd {
|
||||
temperature: None,
|
||||
prompt: "".to_string(),
|
||||
config: None,
|
||||
model_id: "karpathy/tinyllamas".to_string(),
|
||||
which_model: "stories15M.bin".to_string(),
|
||||
};
|
||||
run_inference(&cmd, &args)?
|
||||
}
|
||||
Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
|
||||
Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
|
||||
Some(Task::Train(cmd)) => training::run(cmd, &args)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -202,6 +227,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
|
@ -15,6 +15,21 @@ pub struct Config {
|
||||
pub norm_eps: f64,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn tiny() -> Self {
|
||||
Self {
|
||||
dim: 288,
|
||||
hidden_dim: 768,
|
||||
n_layers: 6,
|
||||
n_heads: 6,
|
||||
n_kv_heads: 6,
|
||||
vocab_size: 32000,
|
||||
seq_len: 256,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
|
168
candle-examples/examples/llama2-c/training.rs
Normal file
168
candle-examples/examples/llama2-c/training.rs
Normal file
@ -0,0 +1,168 @@
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
use crate::model::{Cache, Config, Llama};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
pub struct Dataset {
|
||||
valid_tokens: Vec<memmap2::Mmap>,
|
||||
train_tokens: Vec<memmap2::Mmap>,
|
||||
}
|
||||
|
||||
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
|
||||
Ok(mmap)
|
||||
}
|
||||
|
||||
impl Dataset {
|
||||
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
|
||||
let dir = dir.as_ref();
|
||||
let mut bin_files = vec![];
|
||||
for file in std::fs::read_dir(dir)?.flatten() {
|
||||
let file = file.path();
|
||||
if let Some(extension) = file.extension() {
|
||||
if extension == "bin" {
|
||||
bin_files.push(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
if bin_files.len() < 2 {
|
||||
candle::bail!("found less than two bin files in {:?}", dir)
|
||||
}
|
||||
bin_files.sort();
|
||||
let valid_tokens = mmap_file(&bin_files[0])?;
|
||||
let train_tokens = bin_files[1..]
|
||||
.iter()
|
||||
.map(mmap_file)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
valid_tokens: vec![valid_tokens],
|
||||
train_tokens,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct DatasetRandomIter<'a> {
|
||||
all_tokens: &'a [memmap2::Mmap],
|
||||
tokens: Vec<&'a memmap2::Mmap>,
|
||||
current_tokens: &'a memmap2::Mmap,
|
||||
indexes_in_bytes: Vec<usize>,
|
||||
seq_len: usize,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl<'a> DatasetRandomIter<'a> {
|
||||
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let all_tokens = if valid {
|
||||
&ds.valid_tokens
|
||||
} else {
|
||||
&ds.train_tokens
|
||||
};
|
||||
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
|
||||
tokens.shuffle(&mut thread_rng());
|
||||
let current_tokens = tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = seq_len * 2;
|
||||
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
Self {
|
||||
all_tokens,
|
||||
tokens,
|
||||
current_tokens,
|
||||
indexes_in_bytes,
|
||||
seq_len,
|
||||
device,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DatasetRandomIter<'a> {
|
||||
type Item = Result<(Tensor, Tensor)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let seq_len = self.seq_len;
|
||||
if self.indexes_in_bytes.is_empty() {
|
||||
if self.tokens.is_empty() {
|
||||
self.tokens = self.all_tokens.iter().collect();
|
||||
self.tokens.shuffle(&mut thread_rng());
|
||||
}
|
||||
self.current_tokens = self.tokens.pop().unwrap();
|
||||
let seq_len_in_bytes = self.seq_len * 2;
|
||||
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
|
||||
.step_by(seq_len_in_bytes)
|
||||
.collect::<Vec<_>>();
|
||||
self.indexes_in_bytes.shuffle(&mut thread_rng());
|
||||
}
|
||||
let start_idx = self.indexes_in_bytes.pop().unwrap();
|
||||
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
|
||||
let mut tokens = vec![0u16; bytes.len() / 2];
|
||||
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
|
||||
return Some(Err(err.into()));
|
||||
}
|
||||
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
|
||||
let targets = Tensor::new(&tokens[1..], &self.device);
|
||||
Some(candle::error::zip(inputs, targets))
|
||||
}
|
||||
}
|
||||
|
||||
fn valid_loss(
|
||||
dataset: &Dataset,
|
||||
model: &Llama,
|
||||
args: &crate::TrainingCmd,
|
||||
device: &Device,
|
||||
) -> Result<f64> {
|
||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
let mut sum_ce = 0f64;
|
||||
let mut cnt = 0usize;
|
||||
for inp_tgt in batch_iter.take(50) {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
sum_ce += loss.to_vec0::<f32>()? as f64;
|
||||
cnt += 1;
|
||||
}
|
||||
Ok(sum_ce / cnt as f64)
|
||||
}
|
||||
|
||||
pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
let dataset = Dataset::new(&args.pretokenized_dir)?;
|
||||
println!(
|
||||
"loaded dataset, train: {} files, valid: {} files",
|
||||
dataset.train_tokens.len(),
|
||||
dataset.valid_tokens.len()
|
||||
);
|
||||
let vb = candle_nn::VarBuilder::zeros(DType::F32, &device);
|
||||
let config = Config::tiny();
|
||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
|
||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
let all_vars = vec![]; // TODO: Propagate the variables from the VarBuilder to here.
|
||||
let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
|
||||
for (batch_index, batch) in batch_iter.enumerate() {
|
||||
let (inp, tgt) = batch?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
if batch_index > 0 && batch_index % 100 == 0 {
|
||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||
// validation loss.
|
||||
let loss = valid_loss(&dataset, &model, args, &device)?;
|
||||
println!("{batch_index} {loss}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user