Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)

Pre-tokenized evaluation mode for llama2.c. (#291)
@@ -215,8 +215,13 @@ struct Args {
     #[arg(long, default_value = "")]
     prompt: String,
 
+    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+    /// script from llama2.c https://github.com/karpathy/llama2.c
     #[arg(long)]
-    eval_file: Option<String>,
+    pretokenized_dir: Option<String>,
+
+    #[arg(long, default_value_t = 32)]
+    batch_size: usize,
 }
 
 fn main() -> anyhow::Result<()> {
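A note on the new flags: with clap's derive API, an `Option<String>` argument stays `None` unless `--pretokenized-dir` is passed, and `default_value_t = 32` gives `--batch-size` a typed default. A minimal standalone sketch (the `Demo` struct is hypothetical, not the example's full `Args`):

use clap::Parser;

/// Hypothetical stand-in for the example's Args struct, showing only the two new flags.
#[derive(Parser, Debug)]
struct Demo {
    /// Stays None unless `--pretokenized-dir <DIR>` is given on the command line.
    #[arg(long)]
    pretokenized_dir: Option<String>,

    /// Defaults to 32 when `--batch-size` is not given.
    #[arg(long, default_value_t = 32)]
    batch_size: usize,
}

fn main() {
    let demo = Demo::parse();
    println!("{demo:?}");
}

Run with no flags, this should print `Demo { pretokenized_dir: None, batch_size: 32 }`.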
@@ -243,13 +248,7 @@ fn main() -> anyhow::Result<()> {
 
     match args.task {
         Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => {
-            if let Some(eval_file) = &args.eval_file {
-                run_eval_file(eval_file.into(), &config_path, args)?
-            } else {
-                run_eval(tokenizer, &config_path, args)?
-            }
-        }
+        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
         Task::Training => todo!(),
     }
     Ok(())
@@ -278,7 +277,6 @@ fn run_eval_file(
     println!("dataset loaded: {} tokens", tokens.len());
 
     let seq_len = model.config.seq_len;
-    let batch_size = 32;
     let mut inputs = vec![];
     let mut targets = vec![];
     for start_idx in (0..tokens.len()).step_by(seq_len) {
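The loop this hunk enters slices the token stream into non-overlapping windows of `seq_len + 1` tokens, so each target sequence is the input sequence shifted by one position (standard next-token prediction). A plain-Rust sketch of the windowing, with illustrative sizes:

// Sketch of the next-token windowing used by run_eval_file; values are illustrative.
fn main() {
    let tokens: Vec<u32> = (0..10).collect();
    let seq_len = 4;
    for start_idx in (0..tokens.len()).step_by(seq_len) {
        // Each window needs seq_len + 1 tokens: seq_len inputs plus one shifted target.
        if start_idx + seq_len + 1 > tokens.len() {
            break;
        }
        let window = &tokens[start_idx..start_idx + seq_len + 1];
        let inputs = &window[..seq_len]; // tokens t_0..t_{n-1}
        let targets = &window[1..]; // tokens t_1..t_n, i.e. inputs shifted by one
        println!("{inputs:?} -> {targets:?}");
    }
}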
@@ -290,7 +288,7 @@ fn run_eval_file(
         let targets_ = Tensor::new(&tokens[1..], &device)?;
         inputs.push(inputs_);
         targets.push(targets_);
-        if inputs.len() >= batch_size {
+        if inputs.len() >= args.batch_size {
             let inp = Tensor::stack(&inputs, 0)?;
             let tgt = Tensor::stack(&targets, 0)?;
             let logits = model.forward(&inp, 0)?;
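`Tensor::stack(&inputs, 0)` turns the accumulated `[seq_len]` rows into a `[batch, seq_len]` tensor, so the logits come back as `[batch, seq_len, vocab]`. `candle_nn::loss::cross_entropy` expects 2-D logits and 1-D class-id targets, hence the `flatten_to(1)` calls that merge the batch and sequence dimensions. A hedged sketch with random logits standing in for `model.forward`, assuming the crate is imported as `candle` as in the candle examples:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    let (batch, seq_len, vocab) = (2, 4, 8); // illustrative sizes
    // Random logits stand in for model.forward(&inp, 0)?; shape [batch, seq_len, vocab].
    let logits = Tensor::randn(0f32, 1f32, (batch, seq_len, vocab), &device)?;
    // All-zero class ids of shape [batch, seq_len]; real targets come from the dataset.
    let tgt = Tensor::zeros((batch, seq_len), DType::U32, &device)?;
    // flatten_to(1) merges dims 0..=1: logits become [batch*seq_len, vocab] and
    // targets become [batch*seq_len], the shapes cross_entropy expects.
    let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
    println!("loss: {}", loss.to_vec0::<f32>()?);
    Ok(())
}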
@@ -314,6 +312,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
     let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
 
+    let tokens = match args.pretokenized_dir {
+        None => {
             let api = hf_hub::api::sync::Api::new()?;
             let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
             println!("loading the evaluation dataset from {}", model_id);
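Between this hunk and the next, the `None` branch fetches the raw TinyStories text through the hf-hub crate. A hedged sketch of that fetch pattern; the exact file name is an assumption, not taken from the diff:

use anyhow::Result;

// Hedged sketch of the hf-hub sync fetch used by the None branch; the
// "TinyStories-valid.txt" file name is an assumption, not shown in the diff.
fn fetch_eval_file() -> Result<std::path::PathBuf> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.dataset("roneneldan/TinyStories".to_string());
    let path = repo.get("TinyStories-valid.txt")?;
    Ok(path)
}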
@@ -323,23 +323,44 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
             let file = std::io::BufReader::new(file);
             let mut tokens = vec![];
             for line in file.lines() {
-                let line = line?.replace("<|endoftext|>", "");
+                let line = line?.replace("<|endoftext|>", "<s>");
                 let line = tokenizer.encode(line, false).map_err(E::msg)?;
                 tokens.push(line.get_ids().to_vec())
             }
-            let tokens = tokens.concat();
+            tokens.concat()
+        }
+        Some(pretokenized_dir) => {
+            let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
+            let bytes = std::fs::read(path)?;
+            // Tokens are encoded as u16.
+            let mut tokens = vec![0u16; bytes.len() / 2];
+            std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+            tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
+        }
+    };
     println!("dataset loaded and encoded: {} tokens", tokens.len());
-    let seq_len = 256;
+    let seq_len = model.config.seq_len;
+    let mut inputs = vec![];
+    let mut targets = vec![];
     for start_idx in (0..tokens.len()).step_by(seq_len) {
         if start_idx + seq_len + 1 > tokens.len() {
             break;
         }
         let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs = Tensor::new(&tokens[..seq_len], &device)?.unsqueeze(0)?;
-        let targets = Tensor::new(&tokens[1..], &device)?;
-        let logits = model.forward(&inputs, 0)?.squeeze(0)?;
-        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
-        println!("{start_idx} {}", loss.to_vec0::<f32>()?);
+        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+        let targets_ = Tensor::new(&tokens[1..], &device)?;
+        inputs.push(inputs_);
+        targets.push(targets_);
+        if inputs.len() >= args.batch_size {
+            let inp = Tensor::stack(&inputs, 0)?;
+            let tgt = Tensor::stack(&targets, 0)?;
+            let logits = model.forward(&inp, 0)?;
+            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+            println!("{}", loss.to_vec0::<f32>()?);
+            inputs.clear();
+            targets.clear();
+        }
     }
     Ok(())
 }
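The `Some` branch reads llama2.c's pre-tokenized format: `data00.bin` is a flat array of little-endian u16 token ids, decoded above with byteorder's `read_u16_into::<LittleEndian>` and widened to u32. A dependency-free sketch of the same decoding using only std:

// Decode a llama2.c-style token file: a flat array of little-endian u16 ids.
// std-only equivalent of byteorder's read_u16_into::<LittleEndian>.
fn decode_tokens(bytes: &[u8]) -> Vec<u32> {
    bytes
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]) as u32)
        .collect()
}

fn main() {
    // Token ids 0x0001 and 0x0102, encoded little-endian.
    let bytes = [0x01, 0x00, 0x02, 0x01];
    assert_eq!(decode_tokens(&bytes), vec![1, 258]);
    println!("ok");
}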