Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Evaluate on the pre-tokenized file. (#290)
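The eval file consumed by the new `run_eval_file` below is expected to hold raw token ids encoded as little-endian u16 values (that is how the `read_u16_into::<LittleEndian>` call decodes it). For reference, a minimal sketch of producing such a file; the helper name and the use of `byteorder` on the write side are illustrative, not part of this commit:

    use byteorder::{LittleEndian, WriteBytesExt};
    use std::io::Write;

    /// Hypothetical helper: serializes token ids as little-endian u16, the
    /// layout run_eval_file decodes. Ids must fit in u16; the llama2.c
    /// vocabulary (32000 tokens) does.
    fn write_pretokenized(token_ids: &[u32], path: &str) -> std::io::Result<()> {
        let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);
        for &id in token_ids {
            out.write_u16::<LittleEndian>(id as u16)?;
        }
        out.flush()
    }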
@@ -214,6 +214,9 @@ struct Args {
     #[arg(long, default_value = "")]
     prompt: String,
+
+    #[arg(long)]
+    eval_file: Option<String>,
 }
 
 fn main() -> anyhow::Result<()> {
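Since the new field uses clap's `#[arg(long)]`, it surfaces on the command line as `--eval-file` (clap derives kebab-case flag names from snake_case fields). Assuming this is the llama2-c example and that `Task` derives clap's ValueEnum (neither is visible in this diff), an invocation along these lines would exercise the new path; the .bin path is a placeholder:

    cargo run --example llama2-c --release -- --task evaluation --eval-file data/val_tokens.bin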
@@ -240,12 +243,66 @@ fn main() -> anyhow::Result<()> {
     match args.task {
         Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
+        Task::Evaluation => {
+            if let Some(eval_file) = &args.eval_file {
+                run_eval_file(eval_file.into(), &config_path, args)?
+            } else {
+                run_eval(tokenizer, &config_path, args)?
+            }
+        }
         Task::Training => todo!(),
     }
     Ok(())
 }
+
+fn run_eval_file(
+    path: std::path::PathBuf,
+    config_path: &std::path::PathBuf,
+    args: Args,
+) -> Result<()> {
+    use std::io::BufRead;
+
+    let device = candle_examples::device(args.cpu)?;
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
+    let bytes = std::fs::read(path)?;
+    // Tokens are encoded as u16.
+    let mut tokens = vec![0u16; bytes.len() / 2];
+    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+    let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
+    println!("dataset loaded: {} tokens", tokens.len());
+
+    let seq_len = model.config.seq_len;
+    let batch_size = 32;
+    let mut inputs = vec![];
+    let mut targets = vec![];
+    for start_idx in (0..tokens.len()).step_by(seq_len) {
+        if start_idx + seq_len + 1 > tokens.len() {
+            break;
+        }
+        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+        let targets_ = Tensor::new(&tokens[1..], &device)?;
+        inputs.push(inputs_);
+        targets.push(targets_);
+        if inputs.len() >= batch_size {
+            let inp = Tensor::stack(&inputs, 0)?;
+            let tgt = Tensor::stack(&targets, 0)?;
+            let logits = model.forward(&inp, 0)?;
+            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+            println!("{}", loss.to_vec0::<f32>()?);
+            inputs.clear();
+            targets.clear();
+        }
+    }
+    Ok(())
+}
+
 fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
     use std::io::BufRead;
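A note on what the evaluation loop above does with the token stream: it advances `start_idx` in steps of `seq_len` while slicing windows of `seq_len + 1` tokens, so the targets are the inputs shifted left by one position (next-token prediction) and consecutive windows overlap by exactly one token. Tokens after the last full window are dropped, and so is any trailing batch smaller than `batch_size`. A minimal candle-free sketch of the windowing, with an illustrative function name:

    /// Splits a token stream into (inputs, targets) pairs for next-token
    /// prediction: windows of seq_len + 1 tokens, advancing by seq_len.
    fn next_token_windows(tokens: &[u32], seq_len: usize) -> Vec<(&[u32], &[u32])> {
        let mut out = Vec::new();
        for start in (0..tokens.len()).step_by(seq_len) {
            // A window needs seq_len inputs plus one extra token for the final target.
            if start + seq_len + 1 > tokens.len() {
                break;
            }
            let w = &tokens[start..start + seq_len + 1];
            out.push((&w[..seq_len], &w[1..]));
        }
        out
    }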
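The `flatten_to(1)` calls merge the batch and sequence dimensions before the loss: logits of shape (batch, seq_len, vocab) become (batch * seq_len, vocab) and the targets become a flat (batch * seq_len,) vector, which is the layout `candle_nn::loss::cross_entropy` takes. The printed scalar is the batch's mean negative log-likelihood in nats, so its exponential is the perplexity. A self-contained sketch of that shape bookkeeping, assuming the external crate names candle_core / candle_nn (inside the candle workspace the core crate is typically imported under an alias):

    use candle_core::{DType, Device, Tensor};

    fn loss_shapes() -> candle_core::Result<()> {
        let dev = Device::Cpu;
        let (batch, seq_len, vocab) = (2usize, 4usize, 8usize);
        // Stand-ins for model logits and the shifted targets.
        let logits = Tensor::zeros((batch, seq_len, vocab), DType::F32, &dev)?;
        let targets = Tensor::zeros((batch, seq_len), DType::U32, &dev)?;
        // flatten_to(1) merges dims 0..=1: (2, 4, 8) -> (8, 8) and (2, 4) -> (8,).
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &targets.flatten_to(1)?)?;
        let nll = loss.to_vec0::<f32>()?;
        // With all-zero logits the distribution is uniform, so nll == ln(vocab).
        println!("loss {nll}, perplexity {}", nll.exp());
        Ok(())
    }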