mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Merge pull request #19 from LaurentMazare/llama_safetensors
Llama safetensors
This commit is contained in:
@ -27,6 +27,9 @@ anyhow = "1"
|
|||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||||
|
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
||||||
|
candle-hub = { path = "../candle-hub" }
|
||||||
|
memmap2 = "0.7.1"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["cuda"]
|
default = ["cuda"]
|
||||||
|
@ -15,11 +15,12 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_hub::{api::Api, Repo, RepoType};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
mod var_store;
|
mod var_store;
|
||||||
use var_store::VarBuilder;
|
mod weights;
|
||||||
|
|
||||||
const CONTEXT_SIZE: usize = 512;
|
const CONTEXT_SIZE: usize = 512;
|
||||||
const START_PROMPT: &str = r"
|
const START_PROMPT: &str = r"
|
||||||
@ -131,9 +132,8 @@ struct Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Embedding {
|
impl Embedding {
|
||||||
fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
|
fn new(embeddings: Tensor) -> Self {
|
||||||
let embeddings = vb.var("weight", (vocab_size, n_embd))?;
|
Self { embeddings }
|
||||||
Ok(Self { embeddings })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
@ -145,42 +145,27 @@ impl Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Linear {
|
struct Linear {
|
||||||
ws: Tensor,
|
weight: Tensor,
|
||||||
bs: Option<Tensor>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Linear {
|
impl Linear {
|
||||||
#[allow(dead_code)]
|
fn new(weight: Tensor) -> Self {
|
||||||
fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
|
Self { weight }
|
||||||
let ws = vb.var("weight", (in_size, out_size))?;
|
|
||||||
let bs = vb.var("bias", out_size)?;
|
|
||||||
Ok(Self { ws, bs: Some(bs) })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
|
|
||||||
let ws = vb.var("weight", (in_size, out_size))?;
|
|
||||||
Ok(Self { ws, bs: None })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
|
let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
|
||||||
let y = match &self.bs {
|
Ok(x)
|
||||||
None => x,
|
|
||||||
Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
|
|
||||||
};
|
|
||||||
Ok(y)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
size: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RmsNorm {
|
impl RmsNorm {
|
||||||
fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
|
fn new(scale: Tensor) -> Self {
|
||||||
let scale = vb.var("scale", &[size])?;
|
Self { scale }
|
||||||
Ok(Self { scale, size })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
@ -188,10 +173,11 @@ impl RmsNorm {
|
|||||||
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||||
|
let size = self.scale.shape().r1()?;
|
||||||
let scale = self
|
let scale = self
|
||||||
.scale
|
.scale
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.broadcast_as((seq_len, self.size))?;
|
.broadcast_as((seq_len, size))?;
|
||||||
Ok((scale * x_normed)?)
|
Ok((scale * x_normed)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -207,17 +193,12 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Mlp {
|
impl Mlp {
|
||||||
fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
|
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
|
||||||
let n_hidden = 8 * n_embd / 3;
|
Self {
|
||||||
let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
|
|
||||||
let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
|
|
||||||
let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
|
|
||||||
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
c_fc1,
|
c_fc1,
|
||||||
c_fc2,
|
c_fc2,
|
||||||
c_proj,
|
c_proj,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
@ -256,10 +237,6 @@ impl Cache {
|
|||||||
let mask: Vec<_> = (0..t)
|
let mask: Vec<_> = (0..t)
|
||||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||||
.collect();
|
.collect();
|
||||||
// Once lower_triangle is available, use the following:
|
|
||||||
//let mask = Tensor::new(1u32, &device)?
|
|
||||||
// .broadcast_as(&[t, t])?
|
|
||||||
// .lower_triangle()?
|
|
||||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||||
masks.insert(t, mask.clone());
|
masks.insert(t, mask.clone());
|
||||||
Ok(mask)
|
Ok(mask)
|
||||||
@ -271,21 +248,18 @@ struct CausalSelfAttention {
|
|||||||
c_attn: Linear,
|
c_attn: Linear,
|
||||||
c_proj: Linear,
|
c_proj: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
n_embd: usize,
|
// n_embd: usize,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
|
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
||||||
let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
|
Self {
|
||||||
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
c_attn,
|
c_attn,
|
||||||
c_proj,
|
c_proj,
|
||||||
n_head,
|
n_head,
|
||||||
n_embd,
|
|
||||||
cache: cache.clone(),
|
cache: cache.clone(),
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -313,7 +287,7 @@ impl CausalSelfAttention {
|
|||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
let (t, c) = x.shape().r2()?;
|
let (t, c) = x.shape().r2()?;
|
||||||
let qkv = self.c_attn.forward(x)?;
|
let qkv = self.c_attn.forward(x)?;
|
||||||
let n_embd = self.n_embd;
|
let n_embd = c;
|
||||||
let q = qkv.narrow(1, 0, n_embd)?;
|
let q = qkv.narrow(1, 0, n_embd)?;
|
||||||
let k = qkv.narrow(1, n_embd, n_embd)?;
|
let k = qkv.narrow(1, n_embd, n_embd)?;
|
||||||
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
||||||
@ -344,17 +318,13 @@ struct Block {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Block {
|
impl Block {
|
||||||
fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||||
let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
|
Self {
|
||||||
let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
|
|
||||||
let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
|
|
||||||
let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
rms_1,
|
rms_1,
|
||||||
attn,
|
attn,
|
||||||
rms_2,
|
rms_2,
|
||||||
mlp,
|
mlp,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -372,23 +342,13 @@ struct Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Llama {
|
impl Llama {
|
||||||
fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
|
||||||
let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
|
Self {
|
||||||
let wte = Embedding::new(
|
|
||||||
&vb / "transformer" / "wte",
|
|
||||||
config.vocab_size,
|
|
||||||
config.n_embd,
|
|
||||||
)?;
|
|
||||||
let blocks = (0..config.n_layer)
|
|
||||||
.map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
wte,
|
wte,
|
||||||
blocks,
|
blocks,
|
||||||
ln_f,
|
ln_f,
|
||||||
lm_head,
|
lm_head,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -434,6 +394,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Use npy instead of safetensors
|
||||||
|
#[arg(long)]
|
||||||
|
npy: bool,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 1.0)]
|
#[arg(long, default_value_t = 1.0)]
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
@ -443,8 +407,9 @@ struct Args {
|
|||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
#[tokio::main]
|
||||||
use rand::prelude::*;
|
async fn main() -> Result<()> {
|
||||||
|
//use rand::prelude::*;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
@ -453,38 +418,44 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
println!("loading tokenizer config");
|
let api = Api::new()?;
|
||||||
let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
|
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||||
|
println!("Filename {tokenizer_filename:?}");
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(START_PROMPT, true)
|
.encode(START_PROMPT, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let weight_path = std::path::Path::new("llama.npz");
|
let mut filenames = vec![];
|
||||||
let weights = if weight_path.exists() {
|
for rfilename in [
|
||||||
println!("loading weights from {weight_path:?}");
|
"model-00001-of-00002.safetensors",
|
||||||
let start_load = std::time::Instant::now();
|
"model-00002-of-00002.safetensors",
|
||||||
let tensors = Tensor::read_npz(weight_path)?;
|
] {
|
||||||
println!("loaded weights in {:?}", start_load.elapsed());
|
let filename = api.get(&repo, rfilename).await?;
|
||||||
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
filenames.push(filename);
|
||||||
Some(tensors)
|
}
|
||||||
} else {
|
|
||||||
println!("cannot find {weight_path:?}, using zero weights");
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let vb = VarBuilder::new::<f32>(&device, weights);
|
|
||||||
|
|
||||||
println!("building the model");
|
println!("building the model");
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(&device);
|
let cache = Cache::new(&device);
|
||||||
let llama = Llama::new(vb, &cache, &config)?;
|
let start = std::time::Instant::now();
|
||||||
|
let llama = if args.npy {
|
||||||
|
println!("building the model (NPY)");
|
||||||
|
Llama::load_npy(&device, &filenames, &cache, &config)?
|
||||||
|
} else {
|
||||||
|
println!("building the model (SF)");
|
||||||
|
Llama::load(&device, &filenames, &cache, &config)?
|
||||||
|
};
|
||||||
|
println!("Loaded in {:?}", start.elapsed());
|
||||||
|
|
||||||
println!("pre-computing the positional embeddings");
|
println!("pre-computing the positional embeddings");
|
||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
let mut rng = thread_rng();
|
//let mut rng = thread_rng();
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
@ -493,8 +464,20 @@ fn main() -> Result<()> {
|
|||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
let logits = llama.forward(&input, &freqs_cis)?;
|
||||||
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
let next_token = logits_v
|
||||||
let next_token = distr.sample(&mut rng) as u32;
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
|
||||||
|
if &val_max > val {
|
||||||
|
(idx_max, val_max)
|
||||||
|
} else {
|
||||||
|
(idx, *val)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.0 as u32;
|
||||||
|
// let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||||
|
|
||||||
|
// let next_token = distr.sample(&mut rng) as u32;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
println!("> {:?}", start_gen.elapsed());
|
println!("> {:?}", start_gen.elapsed());
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
|
use super::*;
|
||||||
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
|
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@ -40,22 +42,15 @@ impl VarBuilder {
|
|||||||
self.vars.borrow().len()
|
self.vars.borrow().len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
|
pub fn var(&self, s: &str) -> Result<Tensor> {
|
||||||
let shape = shape.into();
|
|
||||||
let path = format!("{}.{s}", self.path.join("."));
|
let path = format!("{}.{s}", self.path.join("."));
|
||||||
let mut vars = self.vars.borrow_mut();
|
|
||||||
let parameter = match self.tensors.as_ref() {
|
let parameter = match self.tensors.as_ref() {
|
||||||
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
|
None => panic!("Cannot find tensors"),
|
||||||
Some(tensors) => match tensors.get(&path) {
|
Some(tensors) => match tensors.get(&path) {
|
||||||
Some(tensor) => tensor.to_device(&self.default_device)?,
|
Some(tensor) => tensor.to_device(&self.default_device)?,
|
||||||
None => panic!("cannot find tensor for {path}"),
|
None => panic!("cannot find tensor for {path}"),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
vars.push(NamedVar {
|
|
||||||
path,
|
|
||||||
dtype: self.default_dtype,
|
|
||||||
shape,
|
|
||||||
});
|
|
||||||
Ok(parameter)
|
Ok(parameter)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
|
|||||||
&self / rhs
|
&self / rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = vb.var("weight")?;
|
||||||
|
Ok(Self { embeddings })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear {
|
||||||
|
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||||
|
let weight = vb.var("weight")?.t()?;
|
||||||
|
Ok(Self { weight })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm {
|
||||||
|
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||||
|
let scale = vb.var("scale")?;
|
||||||
|
Ok(Self::new(scale))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalSelfAttention {
|
||||||
|
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||||
|
let c_attn = Linear::load_npy(&vb / "c_attn")?;
|
||||||
|
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
||||||
|
Ok(Self::new(c_attn, c_proj, config.n_head, cache))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||||
|
let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
|
||||||
|
let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
|
||||||
|
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
||||||
|
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||||
|
let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
|
||||||
|
let mlp = Mlp::load_npy(&vb / "mlp")?;
|
||||||
|
let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
|
||||||
|
let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
|
||||||
|
Ok(Self::new(
|
||||||
|
input_layernorm,
|
||||||
|
attn,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Llama {
|
||||||
|
pub fn load_npy(
|
||||||
|
device: &Device,
|
||||||
|
_filenames: &[PathBuf],
|
||||||
|
cache: &Cache,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight_path = std::path::Path::new("/data/llama.npz");
|
||||||
|
let weights = if weight_path.exists() {
|
||||||
|
println!("loading weights from {weight_path:?}");
|
||||||
|
let start_load = std::time::Instant::now();
|
||||||
|
let tensors = Tensor::read_npz(weight_path)?;
|
||||||
|
println!("loaded weights in {:?}", start_load.elapsed());
|
||||||
|
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
||||||
|
Some(tensors)
|
||||||
|
} else {
|
||||||
|
println!("cannot find {weight_path:?}, using zero weights");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let vb = VarBuilder::new::<f32>(device, weights);
|
||||||
|
|
||||||
|
let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
|
||||||
|
let lm_head = Linear::load_npy(&vb / "lm_head")?;
|
||||||
|
let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
|
||||||
|
let blocks: Vec<_> = (0..config.n_layer)
|
||||||
|
.map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
164
candle-core/examples/llama/weights.rs
Normal file
164
candle-core/examples/llama/weights.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
use super::*;
|
||||||
|
use candle::{Device, Result, Tensor};
|
||||||
|
use half::f16;
|
||||||
|
use memmap2::MmapOptions;
|
||||||
|
use safetensors::{
|
||||||
|
tensor::{Dtype, TensorView},
|
||||||
|
SafeTensors,
|
||||||
|
};
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||||
|
match view.dtype() {
|
||||||
|
Dtype::F16 => {
|
||||||
|
let v = view.data();
|
||||||
|
if (v.as_ptr() as usize) % 2 == 0 {
|
||||||
|
// SAFETY This is safe because we just checked that this
|
||||||
|
// was correctly aligned.
|
||||||
|
let data: &[f16] =
|
||||||
|
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
||||||
|
Tensor::from_slice(data, view.shape(), device)
|
||||||
|
} else {
|
||||||
|
let mut c = Vec::with_capacity(v.len() / 2);
|
||||||
|
let mut i = 0;
|
||||||
|
while i < v.len() {
|
||||||
|
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
||||||
|
i += 2;
|
||||||
|
}
|
||||||
|
Tensor::from_slice(&c, view.shape(), device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dt => todo!("Unhandled dtype {dt:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VarBuilder<'a> {
|
||||||
|
routing: HashMap<String, usize>,
|
||||||
|
safetensors: Vec<SafeTensors<'a>>,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> VarBuilder<'a> {
|
||||||
|
pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
|
||||||
|
let mut routing = HashMap::new();
|
||||||
|
for (index, sf) in safetensors.iter().enumerate() {
|
||||||
|
for k in sf.names() {
|
||||||
|
routing.insert(k.to_string(), index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
safetensors,
|
||||||
|
device,
|
||||||
|
routing,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
|
||||||
|
// Unwrap or 0 just to let the proper error flow.
|
||||||
|
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
||||||
|
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
|
||||||
|
let tensor = convert(view, &self.device)?;
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear {
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let weight = vb.get(&format!("{prefix}.weight"))?;
|
||||||
|
Ok(Self::new(weight))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let weights: Vec<_> = prefixes
|
||||||
|
.iter()
|
||||||
|
.map(|p| vb.get(&format!("{p}.weight")).unwrap())
|
||||||
|
.collect();
|
||||||
|
let weight = Tensor::cat(&weights, 0)?;
|
||||||
|
Ok(Self::new(weight))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm {
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let scale = vb.get(&format!("{prefix}.weight"))?;
|
||||||
|
Ok(Self::new(scale))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalSelfAttention {
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||||
|
let c_attn = Linear::load_multi(
|
||||||
|
&[
|
||||||
|
&format!("{prefix}.q_proj"),
|
||||||
|
&format!("{prefix}.k_proj"),
|
||||||
|
&format!("{prefix}.v_proj"),
|
||||||
|
],
|
||||||
|
vb,
|
||||||
|
)?;
|
||||||
|
let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
|
||||||
|
Ok(Self::new(c_attn, o_proj, config.n_head, cache))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
||||||
|
let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
|
||||||
|
let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
|
||||||
|
let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
|
||||||
|
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||||
|
let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
|
||||||
|
let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
|
||||||
|
let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
|
||||||
|
let post_attention_layernorm =
|
||||||
|
RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
|
||||||
|
Ok(Self::new(
|
||||||
|
input_layernorm,
|
||||||
|
attn,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Llama {
|
||||||
|
pub fn load(
|
||||||
|
device: &Device,
|
||||||
|
filenames: &[PathBuf],
|
||||||
|
cache: &Cache,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let handles: Vec<_> = filenames
|
||||||
|
.iter()
|
||||||
|
.map(|f| {
|
||||||
|
let file = File::open(f).unwrap();
|
||||||
|
unsafe { MmapOptions::new().map(&file).unwrap() }
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tensors: Vec<_> = handles
|
||||||
|
.iter()
|
||||||
|
.map(|h| {
|
||||||
|
let tensors = SafeTensors::deserialize(h).unwrap();
|
||||||
|
tensors
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let vb = VarBuilder::new(tensors, device.clone());
|
||||||
|
|
||||||
|
let embedding = vb.get("model.embed_tokens.weight")?;
|
||||||
|
let wte = Embedding::new(embedding);
|
||||||
|
let lm_head = Linear::load("lm_head", &vb)?;
|
||||||
|
let norm = RmsNorm::load("model.norm", &vb)?;
|
||||||
|
let blocks: Vec<_> = (0..config.n_layer)
|
||||||
|
.map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||||
|
}
|
||||||
|
}
|
@ -378,6 +378,12 @@ impl Api {
|
|||||||
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
|
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
|
||||||
let filename = temp_filename();
|
let filename = temp_filename();
|
||||||
|
|
||||||
|
// Create the file and set everything properly
|
||||||
|
tokio::fs::File::create(&filename)
|
||||||
|
.await?
|
||||||
|
.set_len(length as u64)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let chunk_size = self.chunk_size;
|
let chunk_size = self.chunk_size;
|
||||||
for start in (0..length).step_by(chunk_size) {
|
for start in (0..length).step_by(chunk_size) {
|
||||||
let url = url.to_string();
|
let url = url.to_string();
|
||||||
@ -440,7 +446,6 @@ impl Api {
|
|||||||
let range = format!("bytes={start}-{stop}");
|
let range = format!("bytes={start}-{stop}");
|
||||||
let mut file = tokio::fs::OpenOptions::new()
|
let mut file = tokio::fs::OpenOptions::new()
|
||||||
.write(true)
|
.write(true)
|
||||||
.create(true)
|
|
||||||
.open(filename)
|
.open(filename)
|
||||||
.await?;
|
.await?;
|
||||||
file.seek(SeekFrom::Start(start as u64)).await?;
|
file.seek(SeekFrom::Start(start as u64)).await?;
|
||||||
|
@ -53,7 +53,11 @@ impl Cache {
|
|||||||
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
||||||
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
||||||
pointer_path.push(filename);
|
pointer_path.push(filename);
|
||||||
Some(pointer_path)
|
if pointer_path.exists() {
|
||||||
|
Some(pointer_path)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a reference in the cache directory that points branches to the correct
|
/// Creates a reference in the cache directory that points branches to the correct
|
||||||
@ -146,7 +150,12 @@ impl Repo {
|
|||||||
|
|
||||||
/// The normalized folder nameof the repo within the cache directory
|
/// The normalized folder nameof the repo within the cache directory
|
||||||
pub fn folder_name(&self) -> String {
|
pub fn folder_name(&self) -> String {
|
||||||
self.repo_id.replace('/', "--")
|
let prefix = match self.repo_type {
|
||||||
|
RepoType::Model => "models",
|
||||||
|
RepoType::Dataset => "datasets",
|
||||||
|
RepoType::Space => "spaces",
|
||||||
|
};
|
||||||
|
format!("{prefix}--{}", self.repo_id).replace('/', "--")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The revision
|
/// The revision
|
||||||
|
Reference in New Issue
Block a user