mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use the hub models for llama2.c (#284)
This commit is contained in:
@ -8,7 +8,7 @@ extern crate intel_mkl_src;
|
||||
mod model;
|
||||
use clap::Parser;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::{Error as E, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
|
||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
||||
@ -181,38 +181,35 @@ struct Args {
|
||||
|
||||
/// Config file in binary format.
|
||||
#[arg(long)]
|
||||
config: String,
|
||||
config: Option<String>,
|
||||
|
||||
/// Tokenizer config file in binary format.
|
||||
/// Tokenizer config file.
|
||||
#[arg(long)]
|
||||
tokenizer: String,
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
}
|
||||
|
||||
struct Tokenizer {
|
||||
tokens: Vec<String>,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
|
||||
let mut tokens = Vec::with_capacity(c.vocab_size);
|
||||
for _token_index in 0..c.vocab_size {
|
||||
let token_len = read_i32(r)?;
|
||||
let mut token = vec![0u8; token_len as usize];
|
||||
r.read_exact(&mut token);
|
||||
tokens.push(String::from_utf8_lossy(&token).into_owned())
|
||||
}
|
||||
Ok(Self { tokens })
|
||||
}
|
||||
#[arg(long, default_value = "karpathy/tinyllamas")]
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let mut file = std::fs::File::open(&args.config)?;
|
||||
let config_path = match &args.config {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
println!("loading the model weights from {}", args.model_id);
|
||||
let api = api.model(args.model_id);
|
||||
api.get("stories15M.bin")?
|
||||
}
|
||||
};
|
||||
let mut file = std::fs::File::open(&config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("config: {config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
@ -220,8 +217,16 @@ fn main() -> anyhow::Result<()> {
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, &config)?;
|
||||
|
||||
let mut file = std::fs::File::open(&args.tokenizer)?;
|
||||
let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
|
||||
let tokenizer_path = match &args.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
println!("{tokenizer_path:?}");
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
|
||||
@ -244,8 +249,15 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
print!("{}", tokenizer.tokens[next_token as usize]);
|
||||
std::io::stdout().flush()?;
|
||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||
// heuristics as it seems to work well enough for this example. See the following for more
|
||||
// details:
|
||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
print!("{text}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
|
Reference in New Issue
Block a user