From 9ae1f6afeecca7b424b0943d591809481dc88dbc Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 31 Jul 2023 17:22:14 +0100
Subject: [PATCH] Add an eval mode to llama2-c (#288)

* Add an eval mode to llama2-c.

* Encode line by line.

* Get the eval to run.
---
 candle-examples/examples/llama2-c/main.rs  | 95 +++++++++++++++++-----
 candle-examples/examples/llama2-c/model.rs | 27 ++++++---
 2 files changed, 87 insertions(+), 35 deletions(-)

diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index dca85ead..65641b3c 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
 
 mod model;
 
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -14,6 +14,7 @@ use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
 use candle_transformers::generation::LogitsProcessor;
 use std::io::Write;
+use tokenizers::Tokenizer;
 
 use model::{Config, Llama};
 
@@ -172,9 +173,20 @@ impl TransformerWeights {
     }
 }
 
+#[derive(ValueEnum, Debug, Clone)]
+enum Task {
+    Inference,
+    Evaluation,
+    Training,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
+    /// The task to be performed: inference, training, or evaluation.
+    #[clap(value_enum, default_value_t = Task::Inference)]
+    task: Task,
+
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
@@ -205,26 +217,16 @@ struct Args {
 }
 
 fn main() -> anyhow::Result<()> {
-    use tokenizers::Tokenizer;
-
     let args = Args::parse();
-    let device = candle_examples::device(args.cpu)?;
     let config_path = match &args.config {
         Some(config) => std::path::PathBuf::from(config),
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             println!("loading the model weights from {}", args.model_id);
-            let api = api.model(args.model_id);
+            let api = api.model(args.model_id.clone());
             api.get(&args.which_model)?
         }
     };
-    let mut file = std::fs::File::open(&config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    println!("config: {config:?}");
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, &config)?;
 
     let tokenizer_path = match &args.tokenizer {
         Some(config) => std::path::PathBuf::from(config),
@@ -234,9 +236,68 @@ fn main() -> anyhow::Result<()> {
             api.get("tokenizer.json")?
         }
     };
-    println!("{tokenizer_path:?}");
     let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
 
+    match args.task {
+        Task::Inference => run_inference(tokenizer, &config_path, args)?,
+        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
+        Task::Training => todo!(),
+    }
+    Ok(())
+}
+
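+// Evaluation mode: tokenize the validation dataset line by line, then report
+// the model's next-token cross-entropy over consecutive windows of 256 tokens.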
+fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
+    use std::io::BufRead;
+
+    let device = candle_examples::device(args.cpu)?;
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
+    println!("loading the evaluation dataset from {}", model_id);
+    let api = api.dataset(model_id.to_string());
+    let dataset_path = api.get("TinyStories-valid.txt")?;
+    let file = std::fs::File::open(dataset_path)?;
+    let file = std::io::BufReader::new(file);
+    let mut tokens = vec![];
+    for line in file.lines() {
+        let line = tokenizer.encode(line?, false).map_err(E::msg)?;
+        tokens.push(line.get_ids().to_vec())
+    }
+    let tokens = tokens.concat();
+    println!("dataset loaded and encoded: {} tokens", tokens.len());
+    let seq_len = 256;
+    for start_idx in (0..tokens.len()).step_by(seq_len) {
+        if start_idx + seq_len + 1 > tokens.len() {
+            break;
+        }
+        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+        let inputs = Tensor::new(&tokens[..seq_len], &device)?.unsqueeze(0)?;
+        let targets = Tensor::new(&tokens[1..], &device)?;
+        let logits = model.forward(&inputs, 0)?.squeeze(0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
+        println!("{start_idx} {}", loss.to_vec0::<f32>()?);
+    }
+    Ok(())
+}
+
+fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
+    let device = candle_examples::device(args.cpu)?;
+
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
@@ -250,19 +309,17 @@ fn main() -> anyhow::Result<()> {
     let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= config.seq_len {
+        if tokens.len() >= model.config.seq_len {
             break;
         }
         let start_gen = std::time::Instant::now();
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;
-        let logits = logits.squeeze(0)?;
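+        // The model now returns the logits for all input positions rather than
+        // just the last one, so index the final position before sampling.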
+        let logits = logits.i((0, logits.dim(1)? - 1))?;
         index_pos += ctxt.len();
         let next_token = logits_processor.sample(&logits)?;
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index fbeb4038..a92367e6 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -286,18 +286,10 @@ pub struct Llama {
     blocks: Vec<Block>,
     ln_f: RmsNorm,
     lm_head: Linear,
+    pub config: Config,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
@@ -305,18 +297,23 @@ impl Llama {
             x = block.forward(&x, index_pos, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
+        let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
            .collect();
-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+            config: cfg,
+        })
     }
 }