mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Final updates -> moving to deterministic for easier comparison.
This commit is contained in:
@ -19,9 +19,7 @@ use candle_hub::{api::Api, Repo, RepoType};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
// mod var_store;
|
||||
// use var_store::VarBuilder;
|
||||
|
||||
mod var_store;
|
||||
mod weights;
|
||||
|
||||
const CONTEXT_SIZE: usize = 512;
|
||||
@ -196,11 +194,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
|
||||
impl Mlp {
|
||||
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
|
||||
// let n_hidden = 8 * n_embd / 3;
|
||||
// let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
|
||||
// let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
|
||||
// let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
|
||||
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
|
||||
Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
@ -244,10 +237,6 @@ impl Cache {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||
.collect();
|
||||
// Once lower_triangle is available, use the followig:
|
||||
//let mask = Tensor::new(1u32, &device)?
|
||||
// .broadcast_as(&[t, t])?
|
||||
// .lower_triangle()?
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||
masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
@ -265,13 +254,10 @@ struct CausalSelfAttention {
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
||||
// let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
|
||||
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
|
||||
Self {
|
||||
c_attn,
|
||||
c_proj,
|
||||
n_head,
|
||||
// n_embd,
|
||||
cache: cache.clone(),
|
||||
}
|
||||
}
|
||||
@ -408,6 +394,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use npy instead of safetensors
|
||||
#[arg(long)]
|
||||
npy: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
temperature: f64,
|
||||
@ -419,7 +409,7 @@ struct Args {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
use rand::prelude::*;
|
||||
//use rand::prelude::*;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
@ -431,6 +421,7 @@ async fn main() -> Result<()> {
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
println!("Filename {tokenizer_filename:?}");
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let mut tokens = tokenizer
|
||||
.encode(START_PROMPT, true)
|
||||
@ -446,30 +437,25 @@ async fn main() -> Result<()> {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
// let weight_path = std::path::Path::new("llama.npz");
|
||||
// let weights = if weight_path.exists() {
|
||||
// println!("loading weights from {weight_path:?}");
|
||||
// let start_load = std::time::Instant::now();
|
||||
// let tensors = Tensor::read_npz(weight_path)?;
|
||||
// println!("loaded weights in {:?}", start_load.elapsed());
|
||||
// let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
||||
// Some(tensors)
|
||||
// } else {
|
||||
// println!("cannot find {weight_path:?}, using zero weights");
|
||||
// None
|
||||
// };
|
||||
// let vb = VarBuilder::new::<f32>(&device, weights);
|
||||
|
||||
println!("building the model");
|
||||
let config = Config::config_7b();
|
||||
let cache = Cache::new(&device);
|
||||
let llama = Llama::load(&device, &filenames, &cache, &config)?;
|
||||
let start = std::time::Instant::now();
|
||||
let llama = if args.npy {
|
||||
println!("building the model (NPY)");
|
||||
Llama::load_npy(&device, &filenames, &cache, &config)?
|
||||
} else {
|
||||
println!("building the model (SF)");
|
||||
Llama::load(&device, &filenames, &cache, &config)?
|
||||
};
|
||||
println!("Loaded in {:?}", start.elapsed());
|
||||
|
||||
println!("pre-computing the positional embeddings");
|
||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||
println!("starting the inference loop");
|
||||
let mut new_tokens = vec![];
|
||||
let mut rng = thread_rng();
|
||||
//let mut rng = thread_rng();
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..args.sample_len {
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||
@ -477,8 +463,20 @@ async fn main() -> Result<()> {
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
let next_token = distr.sample(&mut rng) as u32;
|
||||
let next_token = logits_v
|
||||
.iter()
|
||||
.enumerate()
|
||||
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
|
||||
if &val_max > val {
|
||||
(idx_max, val_max)
|
||||
} else {
|
||||
(idx, *val)
|
||||
}
|
||||
})
|
||||
.0 as u32;
|
||||
// let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
|
||||
// let next_token = distr.sample(&mut rng) as u32;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
println!(
|
||||
|
@ -1,5 +1,7 @@
|
||||
use super::*;
|
||||
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -40,22 +42,15 @@ impl VarBuilder {
|
||||
self.vars.borrow().len()
|
||||
}
|
||||
|
||||
pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
|
||||
let shape = shape.into();
|
||||
pub fn var(&self, s: &str) -> Result<Tensor> {
|
||||
let path = format!("{}.{s}", self.path.join("."));
|
||||
let mut vars = self.vars.borrow_mut();
|
||||
let parameter = match self.tensors.as_ref() {
|
||||
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
|
||||
None => panic!("Cannot find tensors"),
|
||||
Some(tensors) => match tensors.get(&path) {
|
||||
Some(tensor) => tensor.to_device(&self.default_device)?,
|
||||
None => panic!("cannot find tensor for {path}"),
|
||||
},
|
||||
};
|
||||
vars.push(NamedVar {
|
||||
path,
|
||||
dtype: self.default_dtype,
|
||||
shape,
|
||||
});
|
||||
Ok(parameter)
|
||||
}
|
||||
|
||||
@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
|
||||
&self / rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||
let embeddings = vb.var("weight")?;
|
||||
Ok(Self { embeddings })
|
||||
}
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.var("weight")?.t()?;
|
||||
Ok(Self { weight })
|
||||
}
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.var("scale")?;
|
||||
Ok(Self::new(scale))
|
||||
}
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||
let c_attn = Linear::load_npy(&vb / "c_attn")?;
|
||||
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
||||
Ok(Self::new(c_attn, c_proj, config.n_head, cache))
|
||||
}
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
||||
let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
|
||||
let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
|
||||
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||
}
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
|
||||
let mlp = Mlp::load_npy(&vb / "mlp")?;
|
||||
let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
|
||||
let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
|
||||
Ok(Self::new(
|
||||
input_layernorm,
|
||||
attn,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
pub fn load_npy(
|
||||
device: &Device,
|
||||
_filenames: &[PathBuf],
|
||||
cache: &Cache,
|
||||
config: &Config,
|
||||
) -> Result<Self> {
|
||||
let weight_path = std::path::Path::new("/data/llama.npz");
|
||||
let weights = if weight_path.exists() {
|
||||
println!("loading weights from {weight_path:?}");
|
||||
let start_load = std::time::Instant::now();
|
||||
let tensors = Tensor::read_npz(weight_path)?;
|
||||
println!("loaded weights in {:?}", start_load.elapsed());
|
||||
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
||||
Some(tensors)
|
||||
} else {
|
||||
println!("cannot find {weight_path:?}, using zero weights");
|
||||
None
|
||||
};
|
||||
let vb = VarBuilder::new::<f32>(device, weights);
|
||||
|
||||
let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
|
||||
let lm_head = Linear::load_npy(&vb / "lm_head")?;
|
||||
let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
|
||||
let blocks: Vec<_> = (0..config.n_layer)
|
||||
.map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
|
||||
.collect();
|
||||
|
||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user