Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Support for MQA for llama v2. (#205)
* Support for MQA for llama v2.
* More llama-v2.
* Move the rotary embedding precomputation in the cache.
* Add a v2 flag.
* Use the hf model.
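The MQA part of the change lives in the model code rather than in the diff shown below: with multi-query / grouped-query attention the model has fewer key/value heads than query heads, so the key/value tensors must be repeated before the usual attention product. A minimal sketch of that idea, assuming a hypothetical `repeat_kv` helper (the real implementation is in the example's model file, which this diff does not include):

```rust
use candle::{Result, Tensor};

/// Hypothetical helper: expand `n_kv_head` key/value heads so that they line
/// up with the `n_head` query heads. With MQA there is a single kv head, with
/// GQA a few; each kv head is repeated `n_rep = n_head / n_kv_head` times so
/// the usual attention product can be reused unchanged.
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        // Plain multi-head attention: keys/values already match the query heads.
        return Ok(x);
    }
    // x: (batch, n_kv_head, seq_len, head_dim)
    let (b, n_kv_head, seq_len, head_dim) = x.dims4()?;
    x.unsqueeze(2)? // (batch, n_kv_head, 1, seq_len, head_dim)
        .expand((b, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b, n_kv_head * n_rep, seq_len, head_dim))
}
```

After this expansion the existing attention path works unchanged; only the k/v projection sizes and the kv-cache shapes depend on the number of key/value heads.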
@@ -15,7 +15,7 @@ extern crate intel_mkl_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle::{DType, Device, Tensor, D};
+use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -76,23 +76,6 @@ Whate'er it bodes, henceforward will I bear
 Upon my target three fair-shining suns.
 ";
 
-fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
-    let n_elem = config.n_embd / config.n_head;
-    let theta: Vec<_> = (0..n_elem)
-        .step_by(2)
-        .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
-        .collect();
-    let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
-        .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
-    let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
-    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
-}
-
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
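The removed helper does not disappear: per the commit message, the rotary-embedding precomputation moves into `model::Cache`, which is also why `Cache::new` starts returning a `Result` later in this diff. A rough sketch of what the cache-side version could look like, reusing the math of the deleted function (struct and field names here are assumptions; the real code lives in the model file, which is not part of this diff):

```rust
use candle::{DType, Device, Result, Tensor};

// Assumed value; stands in for the example's MAX_SEQ_LEN constant.
const MAX_SEQ_LEN: usize = 4096;

// Minimal stand-in for the example's Config, reduced to the fields used here.
pub struct Config {
    pub n_embd: usize,
    pub n_head: usize,
}

// Sketch of a cache that owns the precomputed rotary tables (field names assumed).
pub struct Cache {
    pub use_kv_cache: bool,
    pub cos: Tensor,
    pub sin: Tensor,
}

impl Cache {
    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
        // Same math as the deleted precompute_freqs_cis, done once at construction
        // time so that the forward pass only needs an integer offset into the tables.
        let n_elem = config.n_embd / config.n_head;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        Ok(Self {
            use_kv_cache,
            cos: idx_theta.cos()?,
            sin: idx_theta.sin()?,
        })
    }
}
```

Storing cos/sin per position means the forward pass only needs an integer offset into these tables, which is what the `index_pos` argument at the end of this diff provides.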
@@ -127,6 +110,12 @@ struct Args {
     /// Use f32 computations rather than f16.
     #[arg(long)]
     use_f32: bool,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    v2: bool,
 }
 
 fn main() -> Result<()> {
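clap derives kebab-case long options from these field names, so the two new arguments are exposed as `--v2` and `--model-id`. A self-contained illustration of just the new fields (reduced from the full `Args` struct above; the doc comments paraphrase the behaviour wired up later in this diff):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Prefer the Llama 2 checkpoint as the default set of weights.
    #[arg(long)]
    v2: bool,

    /// Explicit Hugging Face Hub model id, overriding the v2 default.
    #[arg(long)]
    model_id: Option<String>,
}

fn main() {
    // e.g. `my-example --v2` or `my-example --model-id some-org/some-model`
    let args = Args::parse();
    println!("{args:?}");
}
```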
@@ -136,7 +125,7 @@ fn main() -> Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
     let config = Config::config_7b();
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+    let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
@@ -146,8 +135,15 @@ fn main() -> Result<()> {
         }
         None => {
            let api = Api::new()?;
-            let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
-            println!("loading the model weights");
+            let model_id = args.model_id.unwrap_or_else(|| {
+                if args.v2 {
+                    "meta-llama/Llama-2-7b-hf".to_string()
+                } else {
+                    "Narsil/amall-7b".to_string()
+                }
+            });
+            println!("loading the model weights from {model_id}");
+            let repo = Repo::new(model_id, RepoType::Model);
             let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
             let mut filenames = vec![];
             for rfilename in [
@@ -180,8 +176,6 @@ fn main() -> Result<()> {
         .get_ids()
         .to_vec();
 
-    println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
     let mut new_tokens = vec![];
@@ -196,12 +190,7 @@ fn main() -> Result<()> {
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let freqs_cis = if cache.use_kv_cache {
-            freqs_cis.narrow(1, index_pos, ctxt.len())?
-        } else {
-            freqs_cis.clone()
-        };
-        let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = llama.forward(&input, index_pos)?;
         let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
 
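Because the precomputed tables now live in the cache, the narrowing that `main` used to do on `freqs_cis` can happen inside the model, driven by `index_pos` alone. A hedged sketch of that lookup (function and parameter names are assumptions; the shapes follow the cache sketch above):

```rust
use candle::{Result, Tensor};

/// Hypothetical model-side helper: select the rows of the cached rotary
/// tables that correspond to the tokens currently being processed.
/// `index_pos` is the absolute position of the first token in the chunk.
fn rotary_tables_for(
    cos_cache: &Tensor, // (MAX_SEQ_LEN, head_dim / 2)
    sin_cache: &Tensor, // (MAX_SEQ_LEN, head_dim / 2)
    index_pos: usize,
    seq_len: usize,
) -> Result<(Tensor, Tensor)> {
    let cos = cos_cache.narrow(0, index_pos, seq_len)?;
    let sin = sin_cache.narrow(0, index_pos, seq_len)?;
    Ok((cos, sin))
}
```

With the kv cache enabled, each decoding step after the prompt processes a single token, so only one row is read per step; with the cache disabled the whole context is re-encoded and the full range is selected, which is what the removed `freqs_cis.clone()` branch used to handle.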