Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)

Merge pull request #19 from LaurentMazare/llama_safetensors

Llama safetensors
@@ -15,11 +15,12 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
 mod var_store;
 use var_store::VarBuilder;
+mod weights;
 
 const CONTEXT_SIZE: usize = 512;
 const START_PROMPT: &str = r"
@@ -131,9 +132,8 @@ struct Embedding {
 }
 
 impl Embedding {
-    fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
-        let embeddings = vb.var("weight", (vocab_size, n_embd))?;
-        Ok(Self { embeddings })
+    fn new(embeddings: Tensor) -> Self {
+        Self { embeddings }
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
@@ -145,42 +145,27 @@ impl Embedding {
 }
 
 struct Linear {
-    ws: Tensor,
-    bs: Option<Tensor>,
+    weight: Tensor,
 }
 
 impl Linear {
-    #[allow(dead_code)]
-    fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        let bs = vb.var("bias", out_size)?;
-        Ok(Self { ws, bs: Some(bs) })
-    }
-
-    fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        Ok(Self { ws, bs: None })
+    fn new(weight: Tensor) -> Self {
+        Self { weight }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
-        let y = match &self.bs {
-            None => x,
-            Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
-        };
-        Ok(y)
+        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        Ok(x)
     }
 }
 
 struct RmsNorm {
     scale: Tensor,
-    size: usize,
 }
 
 impl RmsNorm {
-    fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
-        let scale = vb.var("scale", &[size])?;
-        Ok(Self { scale, size })
+    fn new(scale: Tensor) -> Self {
+        Self { scale }
    }
 
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
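Note on the reworked Linear: checkpoints exported from PyTorch (such as these safetensors shards) store a linear layer's weight with shape (out_features, in_features), whereas the old VarBuilder path created it as (in_size, out_size); hence the extra `.t()` before the matmul. The bias variant is dropped entirely since Llama's projections are bias-free. A minimal sketch of the shape bookkeeping, with illustrative dimensions not taken from the commit:

    use anyhow::Result;
    use candle::{DType, Device, Tensor};

    // x: (t, in), weight: (out, in) -> y: (t, out), matching the new forward.
    fn linear_no_bias(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
        Ok(x.matmul(&weight.t()?)?)
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::zeros((4, 16), DType::F32, &device)?;
        let w = Tensor::zeros((32, 16), DType::F32, &device)?;
        assert_eq!(linear_no_bias(&x, &w)?.shape().dims(), &[4, 32]);
        Ok(())
    }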
@@ -188,10 +173,11 @@ impl RmsNorm {
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, self.size))?;
+            .broadcast_as((seq_len, size))?;
         Ok((scale * x_normed)?)
     }
 }
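With the size read back via `self.scale.shape().r1()?` (the length of the rank-1 scale tensor), RmsNorm no longer needs a stored `size` field: the dimension travels with the tensor loaded from the checkpoint. For reference, a scalar sketch of what the forward pass computes per row, using the same 1e-5 epsilon (plain slices, not candle API):

    // y_i = scale_i * x_i / sqrt(mean(x^2) + 1e-5)
    fn rms_norm(x: &[f32], scale: &[f32]) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let denom = (mean_sq + 1e-5).sqrt();
        x.iter().zip(scale).map(|(v, s)| s * v / denom).collect()
    }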
@@ -207,17 +193,12 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl Mlp {
-    fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
-        let n_hidden = 8 * n_embd / 3;
-        let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
-        Ok(Self {
+    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+        Self {
             c_fc1,
             c_fc2,
             c_proj,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
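The constructor no longer derives the SwiGLU hidden width from `n_embd`; with weights coming from the checkpoint, the width is implied by the tensors' shapes. The removed rounding rule is kept below for reference: for the 7B config (`n_embd = 4096`) it gives 8 * 4096 / 3 = 10922, rounded up to the next multiple of 256, i.e. 11008, which is Llama-7B's MLP width.

    // The removed hidden-size computation, as a standalone function.
    fn mlp_hidden(n_embd: usize) -> usize {
        let n_hidden = 8 * n_embd / 3;
        (n_hidden - 1) / 256 * 256 + 256
    }

    fn main() {
        assert_eq!(mlp_hidden(4096), 11008);
    }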
@@ -256,10 +237,6 @@ impl Cache {
         let mask: Vec<_> = (0..t)
             .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
             .collect();
-        // Once lower_triangle is available, use the following:
-        //let mask = Tensor::new(1u32, &device)?
-        //    .broadcast_as(&[t, t])?
-        //    .lower_triangle()?
         let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
         masks.insert(t, mask.clone());
         Ok(mask)
@@ -271,21 +248,18 @@ struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
-    n_embd: usize,
+    // n_embd: usize,
     cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
-        let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
-        Ok(Self {
+    fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
+        Self {
             c_attn,
             c_proj,
             n_head,
-            n_embd,
             cache: cache.clone(),
-        })
+        }
     }
 
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -313,7 +287,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
-        let n_embd = self.n_embd;
+        let n_embd = c;
         let q = qkv.narrow(1, 0, n_embd)?;
         let k = qkv.narrow(1, n_embd, n_embd)?;
         let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
@@ -344,17 +318,13 @@ struct Block {
 }
 
 impl Block {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
-        let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
-        let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
-        Ok(Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+        Self {
             rms_1,
             attn,
             rms_2,
             mlp,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -372,23 +342,13 @@ struct Llama {
 }
 
 impl Llama {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
-        let wte = Embedding::new(
-            &vb / "transformer" / "wte",
-            config.vocab_size,
-            config.n_embd,
-        )?;
-        let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
-            .collect::<Result<Vec<_>>>()?;
-        let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
-        Ok(Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+        Self {
             wte,
             blocks,
             ln_f,
             lm_head,
-        })
+        }
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -434,6 +394,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Use npy instead of safetensors
+    #[arg(long)]
+    npy: bool,
+
     /// The temperature used to generate samples.
     #[arg(long, default_value_t = 1.0)]
     temperature: f64,
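The new `--npy` flag keeps the old npy/npz loading path reachable, e.g. `cargo run --example llama --release -- --npy` (the exact invocation is assumed here, not shown in the commit); without the flag, the example downloads and loads the safetensors shards.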
@@ -443,8 +407,9 @@ struct Args {
     sample_len: usize,
 }
 
-fn main() -> Result<()> {
-    use rand::prelude::*;
+#[tokio::main]
+async fn main() -> Result<()> {
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
@@ -453,38 +418,44 @@ fn main() -> Result<()> {
     } else {
         Device::new_cuda(0)?
     };
-    println!("loading tokenizer config");
-    let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
+    let api = Api::new()?;
+    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Filename {tokenizer_filename:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
 
-    let weight_path = std::path::Path::new("llama.npz");
-    let weights = if weight_path.exists() {
-        println!("loading weights from {weight_path:?}");
-        let start_load = std::time::Instant::now();
-        let tensors = Tensor::read_npz(weight_path)?;
-        println!("loaded weights in {:?}", start_load.elapsed());
-        let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-        Some(tensors)
-    } else {
-        println!("cannot find {weight_path:?}, using zero weights");
-        None
-    };
-    let vb = VarBuilder::new::<f32>(&device, weights);
+    let mut filenames = vec![];
+    for rfilename in [
+        "model-00001-of-00002.safetensors",
+        "model-00002-of-00002.safetensors",
+    ] {
+        let filename = api.get(&repo, rfilename).await?;
+        filenames.push(filename);
+    }
 
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::new(vb, &cache, &config)?;
+    let start = std::time::Instant::now();
+    let llama = if args.npy {
+        println!("building the model (NPY)");
+        Llama::load_npy(&device, &filenames, &cache, &config)?
+    } else {
+        println!("building the model (SF)");
+        Llama::load(&device, &filenames, &cache, &config)?
+    };
+    println!("Loaded in {:?}", start.elapsed());
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    //let mut rng = thread_rng();
+    let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
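`main` is now async under `#[tokio::main]` because the candle_hub `api.get(&repo, ...)` calls are awaited to fetch the tokenizer and the two safetensors shards, which `Llama::load` then builds the model from. The commit's weights.rs is not shown on this page; as a rough sketch of the format side only, this is how the generic `safetensors` crate exposes such a shard (an assumed dependency, not necessarily what `Llama::load` does internally):

    use safetensors::SafeTensors;

    fn inspect(path: &std::path::Path) -> anyhow::Result<()> {
        let data = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&data)?;
        for name in st.names() {
            let view = st.tensor(name)?;
            // Shape, dtype and raw bytes: enough to rebuild a candle Tensor.
            println!("{name}: {:?} {:?}", view.dtype(), view.shape());
        }
        Ok(())
    }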
@@ -493,8 +464,20 @@ fn main() -> Result<()> {
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
-        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-        let next_token = distr.sample(&mut rng) as u32;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                if &val_max > val {
+                    (idx_max, val_max)
+                } else {
+                    (idx, *val)
+                }
+            })
+            .0 as u32;
+        // let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+        // let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!("> {:?}", start_gen.elapsed());