Use the hub models for llama2.c (#284)
@@ -8,7 +8,7 @@ extern crate intel_mkl_src;
 mod model;
 use clap::Parser;
 
-use anyhow::Result;
+use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
@@ -181,38 +181,35 @@ struct Args {
 
     /// Config file in binary format.
     #[arg(long)]
-    config: String,
+    config: Option<String>,
 
-    /// Tokenizer config file in binary format.
+    /// Tokenizer config file.
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
     temperature: Option<f64>,
-}
 
-struct Tokenizer {
-    tokens: Vec<String>,
-}
-
-impl Tokenizer {
-    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
-        let mut tokens = Vec::with_capacity(c.vocab_size);
-        for _token_index in 0..c.vocab_size {
-            let token_len = read_i32(r)?;
-            let mut token = vec![0u8; token_len as usize];
-            r.read_exact(&mut token);
-            tokens.push(String::from_utf8_lossy(&token).into_owned())
-        }
-        Ok(Self { tokens })
-    }
+    #[arg(long, default_value = "karpathy/tinyllamas")]
+    model_id: String,
 }
 
 fn main() -> anyhow::Result<()> {
+    use tokenizers::Tokenizer;
+
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let mut file = std::fs::File::open(&args.config)?;
+    let config_path = match &args.config {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            println!("loading the model weights from {}", args.model_id);
+            let api = api.model(args.model_id);
+            api.get("stories15M.bin")?
+        }
+    };
+    let mut file = std::fs::File::open(&config_path)?;
     let config = Config::from_reader(&mut file)?;
     println!("config: {config:?}");
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
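
Note: the hunk above makes --config optional; when it is omitted, the example pulls stories15M.bin from the karpathy/tinyllamas repo on the Hugging Face Hub. A minimal sketch of that download path, assuming the hf_hub crate's blocking API exactly as used in the diff (the fetch_weights helper name is illustrative, not part of the example):

use std::path::PathBuf;

// Download (or reuse a cached copy of) a file from a model repo on the hub.
// Mirrors the calls made in the diff; repo id and file name are the defaults there.
fn fetch_weights() -> anyhow::Result<PathBuf> {
    let api = hf_hub::api::sync::Api::new()?; // hub client using the default cache location
    let repo = api.model("karpathy/tinyllamas".to_string()); // target model repository
    let path = repo.get("stories15M.bin")?; // resolves to a local path inside the hub cache
    Ok(path)
}

With that fallback the example runs without any local files, while an explicit --config path still takes precedence.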
@@ -220,8 +217,16 @@ fn main() -> anyhow::Result<()> {
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, &config)?;
 
-    let mut file = std::fs::File::open(&args.tokenizer)?;
-    let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
+    let tokenizer_path = match &args.tokenizer {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+            api.get("tokenizer.json")?
+        }
+    };
+    println!("{tokenizer_path:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
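
The tokenizer gets the same treatment: --tokenizer is now optional and falls back to tokenizer.json from hf-internal-testing/llama-tokenizer. A sketch of that resolution logic as a standalone helper, assuming the hf_hub and tokenizers crates as used above (load_tokenizer is an illustrative name only):

fn load_tokenizer(path: Option<&str>) -> anyhow::Result<tokenizers::Tokenizer> {
    let tokenizer_path = match path {
        // An explicit --tokenizer argument takes precedence.
        Some(p) => std::path::PathBuf::from(p),
        // Otherwise fetch tokenizer.json from the hub repo used in the diff.
        None => hf_hub::api::sync::Api::new()?
            .model("hf-internal-testing/llama-tokenizer".to_string())
            .get("tokenizer.json")?,
    };
    // Tokenizer::from_file returns a boxed error that `?` cannot convert into
    // anyhow::Error directly, hence the explicit map_err.
    tokenizers::Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}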
@@ -244,8 +249,15 @@ fn main() -> anyhow::Result<()> {
 
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
-        print!("{}", tokenizer.tokens[next_token as usize]);
-        std::io::stdout().flush()?;
+        // Extracting the last token as a string is complicated, here we just apply some simple
+        // heuristics as it seems to work well enough for this example. See the following for more
+        // details:
+        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+        if let Some(text) = tokenizer.id_to_token(next_token) {
+            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+            print!("{text}");
+            std::io::stdout().flush()?;
+        }
     }
     let dt = start_gen.elapsed();
     println!(
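
The last hunk swaps the old vocab-table lookup for the tokenizers crate's id_to_token plus the simple clean-up described in the linked issue: '▁' is the SentencePiece word-boundary marker and "<0x0A>" the byte token for a newline. A small self-contained sketch of that heuristic (print_token is an illustrative name, not part of the example):

use std::io::Write;

// Decode and print a single generated token id using the heuristic from the diff.
fn print_token(tokenizer: &tokenizers::Tokenizer, id: u32) -> std::io::Result<()> {
    if let Some(text) = tokenizer.id_to_token(id) {
        // '▁' marks a word boundary, "<0x0A>" encodes a newline in the byte fallback vocab.
        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
        print!("{text}");
        std::io::stdout().flush()?; // flush so tokens appear as they are generated
    }
    Ok(())
}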