Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add an eval mode to llama2-c (#288)
* Add an eval mode to llama2-c.
* Encode line by line.
* Get the eval to run.
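
The hunks below touch the example's main.rs (argument parsing, run_inference, and the new run_eval) and its model.rs (Llama now keeps its Config and forward returns logits for every position). With the clap setup added below, the task is a positional argument defaulting to inference, so the new mode is presumably selected with something like "cargo run --example llama2-c --release -- evaluation" (clap's ValueEnum uses the lower-cased variant names inference, evaluation, training by default).

As orientation, here is a minimal, hypothetical sketch (not part of this commit) of the windowed next-token evaluation that run_eval performs: the encoded token stream is cut into chunks of seq_len + 1 tokens, the inputs are the first seq_len tokens of a chunk, the targets are the same tokens shifted by one, and the score is the cross entropy between per-position logits and targets. A zero tensor stands in for the model's logits so the sketch runs on its own; the candle / candle_nn crate names follow the example's own imports.

// Hypothetical sketch of the eval windowing; the token stream and the logits
// are stand-ins, not the commit's model or dataset.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let vocab_size = 32usize;
    let seq_len = 8usize;
    // Stand-in token stream; in the commit this comes from encoding
    // TinyStories-valid.txt line by line.
    let tokens: Vec<u32> = (0..100u32).map(|i| i % vocab_size as u32).collect();
    for start_idx in (0..tokens.len()).step_by(seq_len) {
        if start_idx + seq_len + 1 > tokens.len() {
            break;
        }
        let chunk = &tokens[start_idx..start_idx + seq_len + 1];
        let _inputs = Tensor::new(&chunk[..seq_len], &device)?.unsqueeze(0)?; // (1, seq_len)
        let targets = Tensor::new(&chunk[1..], &device)?; // (seq_len,), inputs shifted by one
        // The commit feeds the inputs through the Llama model; uniform
        // zero logits keep this sketch self-contained.
        let logits = Tensor::zeros((seq_len, vocab_size), DType::F32, &device)?;
        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
        println!("{start_idx} {}", loss.to_vec0::<f32>()?);
    }
    Ok(())
}
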
@@ -6,7 +6,7 @@
 extern crate intel_mkl_src;
 
 mod model;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -14,6 +14,7 @@ use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
 use candle_transformers::generation::LogitsProcessor;
 use std::io::Write;
+use tokenizers::Tokenizer;
 
 use model::{Config, Llama};
 
@@ -172,9 +173,20 @@ impl TransformerWeights {
     }
 }
 
+#[derive(ValueEnum, Debug, Clone)]
+enum Task {
+    Inference,
+    Evaluation,
+    Training,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
+    /// The task to be performed, inference, training or evaluation.
+    #[clap(value_enum, default_value_t = Task::Inference)]
+    task: Task,
+
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
@@ -205,26 +217,16 @@ struct Args {
 }
 
 fn main() -> anyhow::Result<()> {
-    use tokenizers::Tokenizer;
-
     let args = Args::parse();
-    let device = candle_examples::device(args.cpu)?;
     let config_path = match &args.config {
         Some(config) => std::path::PathBuf::from(config),
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             println!("loading the model weights from {}", args.model_id);
-            let api = api.model(args.model_id);
+            let api = api.model(args.model_id.clone());
             api.get(&args.which_model)?
         }
     };
-    let mut file = std::fs::File::open(&config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    println!("config: {config:?}");
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, &config)?;
 
     let tokenizer_path = match &args.tokenizer {
         Some(config) => std::path::PathBuf::from(config),
@@ -234,9 +236,66 @@ fn main() -> anyhow::Result<()> {
             api.get("tokenizer.json")?
         }
     };
     println!("{tokenizer_path:?}");
     let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
 
+    match args.task {
+        Task::Inference => run_inference(tokenizer, &config_path, args)?,
+        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
+        Task::Training => todo!(),
+    }
+    Ok(())
+}
+
+fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
+    use std::io::BufRead;
+
+    let device = candle_examples::device(args.cpu)?;
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
+    println!("loading the evaluation dataset from {}", model_id);
+    let api = api.dataset(model_id.to_string());
+    let dataset_path = api.get("TinyStories-valid.txt")?;
+    let file = std::fs::File::open(dataset_path)?;
+    let file = std::io::BufReader::new(file);
+    let mut tokens = vec![];
+    for line in file.lines() {
+        let line = tokenizer.encode(line?, false).map_err(E::msg)?;
+        tokens.push(line.get_ids().to_vec())
+    }
+    let tokens = tokens.concat();
+    println!("dataset loaded and encoded: {} tokens", tokens.len());
+    let seq_len = 256;
+    for start_idx in (0..tokens.len()).step_by(seq_len) {
+        if start_idx + seq_len + 1 > tokens.len() {
+            break;
+        }
+        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+        let inputs = Tensor::new(&tokens[..seq_len], &device)?.unsqueeze(0)?;
+        let targets = Tensor::new(&tokens[1..], &device)?;
+        let logits = model.forward(&inputs, 0)?.squeeze(0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
+        println!("{start_idx} {}", loss.to_vec0::<f32>()?);
+    }
+    Ok(())
+}
+
+fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
+    let device = candle_examples::device(args.cpu)?;
+
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
@@ -250,19 +309,15 @@ fn main() -> anyhow::Result<()> {
 
+    let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= config.seq_len {
+        if tokens.len() >= model.config.seq_len {
             break;
         }
-        let start_gen = std::time::Instant::now();
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;
-        let logits = logits.squeeze(0)?;
+        let logits = logits.i((0, logits.dim(1)? - 1))?;
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
 
@@ -286,18 +286,10 @@ pub struct Llama {
     blocks: Vec<Block>,
     ln_f: RmsNorm,
     lm_head: Linear,
+    pub config: Config,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
@@ -305,18 +297,23 @@ impl Llama {
             x = block.forward(&x, index_pos, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
+        let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();
-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+            config: cfg,
+        })
     }
 }