Use the hub models for llama2.c (#284)

This commit is contained in:
Laurent Mazare
2023-07-31 12:51:14 +01:00
committed by GitHub
parent 62a9b03715
commit 94a43faaca

View File

@ -8,7 +8,7 @@ extern crate intel_mkl_src;
mod model; mod model;
use clap::Parser; use clap::Parser;
use anyhow::Result; use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor}; use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
use candle_nn::{Embedding, Linear, VarBuilder}; use candle_nn::{Embedding, Linear, VarBuilder};
@ -181,38 +181,35 @@ struct Args {
/// Config file in binary format. /// Config file in binary format.
#[arg(long)] #[arg(long)]
config: String, config: Option<String>,
/// Tokenizer config file in binary format. /// Tokenizer config file.
#[arg(long)] #[arg(long)]
tokenizer: String, tokenizer: Option<String>,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
}
struct Tokenizer { #[arg(long, default_value = "karpathy/tinyllamas")]
tokens: Vec<String>, model_id: String,
}
impl Tokenizer {
fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
let mut tokens = Vec::with_capacity(c.vocab_size);
for _token_index in 0..c.vocab_size {
let token_len = read_i32(r)?;
let mut token = vec![0u8; token_len as usize];
r.read_exact(&mut token);
tokens.push(String::from_utf8_lossy(&token).into_owned())
}
Ok(Self { tokens })
}
} }
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse(); let args = Args::parse();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let mut file = std::fs::File::open(&args.config)?; let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id);
api.get("stories15M.bin")?
}
};
let mut file = std::fs::File::open(&config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
println!("config: {config:?}"); println!("config: {config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
@ -220,8 +217,16 @@ fn main() -> anyhow::Result<()> {
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, &config)?; let model = Llama::load(vb, &cache, &config)?;
let mut file = std::fs::File::open(&args.tokenizer)?; let tokenizer_path = match &args.tokenizer {
let tokenizer = Tokenizer::from_reader(&mut file, &config)?; Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
api.get("tokenizer.json")?
}
};
println!("{tokenizer_path:?}");
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
println!("starting the inference loop"); println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature); let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
@ -244,8 +249,15 @@ fn main() -> anyhow::Result<()> {
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
print!("{}", tokenizer.tokens[next_token as usize]); // Extracting the last token as a string is complicated, here we just apply some simple
std::io::stdout().flush()?; // heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
} }
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(