From 1a07ff8d176c6c223652bf5b11cd9e9146c24ff3 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 1 Aug 2023 05:36:25 +0100
Subject: [PATCH] Pre-tokenized evaluation mode for llama2.c. (#291)

---
 candle-examples/examples/llama2-c/main.rs | 81 ++++++++++++++---------
 1 file changed, 51 insertions(+), 30 deletions(-)

diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index b627bd3d..ac17aab1 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -215,8 +215,13 @@ struct Args {
     #[arg(long, default_value = "")]
     prompt: String,
 
+    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+    /// script from llama2.c https://github.com/karpathy/llama2.c
     #[arg(long)]
-    eval_file: Option<String>,
+    pretokenized_dir: Option<String>,
+
+    #[arg(long, default_value_t = 32)]
+    batch_size: usize,
 }
 
 fn main() -> anyhow::Result<()> {
@@ -243,13 +248,7 @@ fn main() -> anyhow::Result<()> {
 
     match args.task {
         Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => {
-            if let Some(eval_file) = &args.eval_file {
-                run_eval_file(eval_file.into(), &config_path, args)?
-            } else {
-                run_eval(tokenizer, &config_path, args)?
-            }
-        }
+        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
         Task::Training => todo!(),
     }
     Ok(())
@@ -278,7 +277,6 @@ fn run_eval_file(
     println!("dataset loaded: {} tokens", tokens.len());
 
     let seq_len = model.config.seq_len;
-    let batch_size = 32;
     let mut inputs = vec![];
     let mut targets = vec![];
     for start_idx in (0..tokens.len()).step_by(seq_len) {
@@ -290,7 +288,7 @@
         let targets_ = Tensor::new(&tokens[1..], &device)?;
         inputs.push(inputs_);
         targets.push(targets_);
-        if inputs.len() >= batch_size {
+        if inputs.len() >= args.batch_size {
             let inp = Tensor::stack(&inputs, 0)?;
             let tgt = Tensor::stack(&targets, 0)?;
             let logits = model.forward(&inp, 0)?;
@@ -314,32 +312,55 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
     let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
 
-    let api = hf_hub::api::sync::Api::new()?;
-    let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
-    println!("loading the evaluation dataset from {}", model_id);
-    let api = api.dataset(model_id.to_string());
-    let dataset_path = api.get("TinyStories-valid.txt")?;
-    let file = std::fs::File::open(dataset_path)?;
-    let file = std::io::BufReader::new(file);
-    let mut tokens = vec![];
-    for line in file.lines() {
-        let line = line?.replace("<|endoftext|>", "");
-        let line = tokenizer.encode(line, false).map_err(E::msg)?;
-        tokens.push(line.get_ids().to_vec())
-    }
-    let tokens = tokens.concat();
+    let tokens = match args.pretokenized_dir {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
+            println!("loading the evaluation dataset from {}", model_id);
+            let api = api.dataset(model_id.to_string());
+            let dataset_path = api.get("TinyStories-valid.txt")?;
+            let file = std::fs::File::open(dataset_path)?;
+            let file = std::io::BufReader::new(file);
+            let mut tokens = vec![];
+            for line in file.lines() {
+                let line = line?.replace("<|endoftext|>", "");
+                let line = tokenizer.encode(line, false).map_err(E::msg)?;
+                tokens.push(line.get_ids().to_vec())
+            }
+            tokens.concat()
+        }
+        Some(pretokenized_dir) => {
+            let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
+            let bytes = std::fs::read(path)?;
+            // Tokens are encoded as u16.
+            let mut tokens = vec![0u16; bytes.len() / 2];
+            std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+            tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
+        }
+    };
     println!("dataset loaded and encoded: {} tokens", tokens.len());
-    let seq_len = 256;
+
+    let seq_len = model.config.seq_len;
+    let mut inputs = vec![];
+    let mut targets = vec![];
     for start_idx in (0..tokens.len()).step_by(seq_len) {
         if start_idx + seq_len + 1 > tokens.len() {
             break;
         }
         let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs = Tensor::new(&tokens[..seq_len], &device)?.unsqueeze(0)?;
-        let targets = Tensor::new(&tokens[1..], &device)?;
-        let logits = model.forward(&inputs, 0)?.squeeze(0)?;
-        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
-        println!("{start_idx} {}", loss.to_vec0::<f32>()?);
+        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+        let targets_ = Tensor::new(&tokens[1..], &device)?;
+        inputs.push(inputs_);
+        targets.push(targets_);
+        if inputs.len() >= args.batch_size {
+            let inp = Tensor::stack(&inputs, 0)?;
+            let tgt = Tensor::stack(&targets, 0)?;
+            let logits = model.forward(&inp, 0)?;
+            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+            println!("{}", loss.to_vec0::<f32>()?);
+            inputs.clear();
+            targets.clear();
+        }
     }
     Ok(())
 }
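Note, not part of the patch above: the new Some(pretokenized_dir) branch reads
data00.bin as a flat array of little-endian u16 token ids, the layout produced
by the tinystories.py pretokenization script from llama2.c. Below is a minimal
sketch for writing such a shard to smoke-test the evaluation path; it assumes
the byteorder crate (whose ReadBytesExt the patched code already relies on),
and the "pretokenized" directory name and token values are made up for the
example.

use byteorder::{LittleEndian, WriteBytesExt};

fn main() -> std::io::Result<()> {
    // Fake token ids; a real shard comes from tinystories.py pretokenize.
    let tokens: Vec<u16> = (0..2048u32).map(|i| (i % 512) as u16).collect();
    // Encode as little-endian u16, the same layout the patch decodes with
    // read_u16_into::<LittleEndian>.
    let mut bytes = Vec::with_capacity(tokens.len() * 2);
    for t in &tokens {
        bytes.write_u16::<LittleEndian>(*t)?;
    }
    std::fs::create_dir_all("pretokenized")?;
    std::fs::write("pretokenized/data00.bin", bytes)?;
    Ok(())
}

With such a directory in place, an evaluation run would point at it through the
two flags this patch adds, --pretokenized-dir and --batch-size; how the
evaluation task itself is selected is outside this diff.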