From 94a43faaca91f2a6c4fe256569bbdc8c47429d97 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 31 Jul 2023 12:51:14 +0100
Subject: [PATCH] Use the hub models for llama2.c (#284)

---
 candle-examples/examples/llama2-c/main.rs | 62 ++++++++++++++---------
 1 file changed, 37 insertions(+), 25 deletions(-)

diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 9b6d1316..20a6267c 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -8,7 +8,7 @@ extern crate intel_mkl_src;
 mod model;
 use clap::Parser;
 
-use anyhow::Result;
+use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
@@ -181,38 +181,35 @@ struct Args {
 
     /// Config file in binary format.
     #[arg(long)]
-    config: String,
+    config: Option<String>,
 
-    /// Tokenizer config file in binary format.
+    /// Tokenizer config file.
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
-}
 
-struct Tokenizer {
-    tokens: Vec<String>,
-}
-
-impl Tokenizer {
-    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
-        let mut tokens = Vec::with_capacity(c.vocab_size);
-        for _token_index in 0..c.vocab_size {
-            let token_len = read_i32(r)?;
-            let mut token = vec![0u8; token_len as usize];
-            r.read_exact(&mut token);
-            tokens.push(String::from_utf8_lossy(&token).into_owned())
-        }
-        Ok(Self { tokens })
-    }
+    #[arg(long, default_value = "karpathy/tinyllamas")]
+    model_id: String,
 }
 
 fn main() -> anyhow::Result<()> {
+    use tokenizers::Tokenizer;
+
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let mut file = std::fs::File::open(&args.config)?;
+    let config_path = match &args.config {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            println!("loading the model weights from {}", args.model_id);
+            let api = api.model(args.model_id);
+            api.get("stories15M.bin")?
+        }
+    };
+    let mut file = std::fs::File::open(&config_path)?;
     let config = Config::from_reader(&mut file)?;
     println!("config: {config:?}");
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
@@ -220,8 +217,16 @@ fn main() -> anyhow::Result<()> {
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, &config)?;
 
-    let mut file = std::fs::File::open(&args.tokenizer)?;
-    let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
+    let tokenizer_path = match &args.tokenizer {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+            api.get("tokenizer.json")?
+        }
+    };
+    println!("{tokenizer_path:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
@@ -244,8 +249,15 @@ fn main() -> anyhow::Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
 
-        print!("{}", tokenizer.tokens[next_token as usize]);
-        std::io::stdout().flush()?;
+        // Extracting the last token as a string is complicated, here we just apply some simple
+        // heuristics as it seems to work well enough for this example. See the following for more
+        // details:
+        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+        if let Some(text) = tokenizer.id_to_token(next_token) {
+            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+            print!("{text}");
+            std::io::stdout().flush()?;
+        }
     }
     let dt = start_gen.elapsed();
     println!(
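
Note: the hub-download pattern this patch introduces can be exercised on its own. Below is a minimal standalone sketch, not part of the patch, assuming the hf-hub crate with its sync API; the repo id "karpathy/tinyllamas" and file name "stories15M.bin" come from the patch, while the fetch_weights helper name is hypothetical:

    use anyhow::Result;

    // Hypothetical helper mirroring the download pattern from the patch:
    // resolve a file from a Hugging Face model repo via the sync API.
    fn fetch_weights() -> Result<std::path::PathBuf> {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.model("karpathy/tinyllamas".to_string());
        // `get` downloads the file into the local cache on first use and
        // returns the on-disk path, which can be opened like any file.
        Ok(repo.get("stories15M.bin")?)
    }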
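
Similarly, the streaming-print heuristic at the end of the patch reads as a small standalone function. A minimal sketch, assuming the tokenizers crate; token_to_text is a hypothetical name, and the two replacements come straight from the patch:

    // Hypothetical helper: map a sampled token id back to printable text.
    // '▁' is the SentencePiece word-boundary marker and "<0x0A>" is the
    // byte-fallback encoding of a newline, so rewriting just these two
    // cases is good enough for streaming output in this example.
    fn token_to_text(tokenizer: &tokenizers::Tokenizer, id: u32) -> Option<String> {
        tokenizer
            .id_to_token(id)
            .map(|t| t.replace('▁', " ").replace("<0x0A>", "\n"))
    }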