diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index d710652f..b627bd3d 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -214,6 +214,9 @@ struct Args {
 
     #[arg(long, default_value = "")]
     prompt: String,
+
+    #[arg(long)]
+    eval_file: Option<String>,
 }
 
 fn main() -> anyhow::Result<()> {
@@ -240,12 +243,66 @@ fn main() -> anyhow::Result<()> {
 
     match args.task {
         Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
+        Task::Evaluation => {
+            if let Some(eval_file) = &args.eval_file {
+                run_eval_file(eval_file.into(), &config_path, args)?
+            } else {
+                run_eval(tokenizer, &config_path, args)?
+            }
+        }
         Task::Training => todo!(),
     }
     Ok(())
 }
 
+fn run_eval_file(
+    path: std::path::PathBuf,
+    config_path: &std::path::PathBuf,
+    args: Args,
+) -> Result<()> {
+    use byteorder::{LittleEndian, ReadBytesExt};
+
+    let device = candle_examples::device(args.cpu)?;
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
+    let bytes = std::fs::read(path)?;
+    // Tokens are encoded as u16.
+    let mut tokens = vec![0u16; bytes.len() / 2];
+    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+    let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
+    println!("dataset loaded: {} tokens", tokens.len());
+
+    let seq_len = model.config.seq_len;
+    let batch_size = 32;
+    let mut inputs = vec![];
+    let mut targets = vec![];
+    for start_idx in (0..tokens.len()).step_by(seq_len) {
+        if start_idx + seq_len + 1 > tokens.len() {
+            break;
+        }
+        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+        let targets_ = Tensor::new(&tokens[1..], &device)?;
+        inputs.push(inputs_);
+        targets.push(targets_);
+        if inputs.len() >= batch_size {
+            let inp = Tensor::stack(&inputs, 0)?;
+            let tgt = Tensor::stack(&targets, 0)?;
+            let logits = model.forward(&inp, 0)?;
+            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+            println!("{}", loss.to_vec0::<f32>()?);
+            inputs.clear();
+            targets.clear();
+        }
+    }
+    Ok(())
+}
+
 fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
     use std::io::BufRead;
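
Not part of the patch: `run_eval_file` reads the dataset as a flat stream of little-endian u16 token ids, matching llama2.c's pre-tokenized `.bin` files. The sketch below shows how a compatible file could be produced for testing. The helper name `write_eval_file` is hypothetical; it assumes every token id fits in a u16 (true for the 32000-entry llama2.c vocabulary) and that the `byteorder` crate is available, as it already is in this example.

    use byteorder::{LittleEndian, WriteBytesExt};
    use std::io::Write;

    /// Writes token ids as consecutive little-endian u16 values, the layout
    /// that `run_eval_file` reads back with `read_u16_into::<LittleEndian>`.
    fn write_eval_file(path: &std::path::Path, tokens: &[u32]) -> anyhow::Result<()> {
        let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);
        for &t in tokens {
            // Fails loudly if a token id does not fit in 16 bits.
            out.write_u16::<LittleEndian>(u16::try_from(t)?)?;
        }
        out.flush()?;
        Ok(())
    }

Each value printed by the evaluation loop is the mean cross-entropy of one 32-sequence batch; exponentiating it gives that batch's perplexity.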