Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Final updates -> moving to deterministic for easier comparison.
@@ -19,9 +19,7 @@ use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
-// mod var_store;
-// use var_store::VarBuilder;
-
+mod var_store;
 mod weights;
 
 const CONTEXT_SIZE: usize = 512;
@@ -196,11 +194,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 
 impl Mlp {
     fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        // let n_hidden = 8 * n_embd / 3;
-        // let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        // let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        // let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
         Self {
             c_fc1,
             c_fc2,
@@ -244,10 +237,6 @@ impl Cache {
         let mask: Vec<_> = (0..t)
             .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
             .collect();
-        // Once lower_triangle is available, use the followig:
-        //let mask = Tensor::new(1u32, &device)?
-        //     .broadcast_as(&[t, t])?
-        //     .lower_triangle()?
         let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
         masks.insert(t, mask.clone());
         Ok(mask)
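The surviving mask construction builds, in row-major order, a t×t matrix whose (i, j) entry is 1 exactly when j > i: a strictly upper-triangular mask of future positions, standing in for the not-yet-available lower_triangle op. A standalone sketch of just that computation, with candle's Tensor left out (illustrative only):

    // Illustrative: the mask layout produced by the flat_map above.
    fn causal_mask(t: usize) -> Vec<u32> {
        (0..t)
            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
            .collect()
    }

    fn main() {
        // t = 3 gives rows [0 1 1], [0 0 1], [0 0 0], flattened:
        assert_eq!(causal_mask(3), vec![0, 1, 1, 0, 0, 1, 0, 0, 0]);
    }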
@@ -265,13 +254,10 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
-        // let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        // let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
         Self {
             c_attn,
             c_proj,
             n_head,
-            // n_embd,
             cache: cache.clone(),
         }
     }
@@ -408,6 +394,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Use npy instead of safetensors
+    #[arg(long)]
+    npy: bool,
+
     /// The temperature used to generate samples.
     #[arg(long, default_value_t = 1.0)]
     temperature: f64,
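The new `npy` field is a plain clap boolean, so the alternate loading path (consumed by the `args.npy` branch in a later hunk) is opt-in from the command line. A hypothetical invocation, assuming the example binary is named `llama` (the name is not shown in this diff):

    cargo run --example llama --release -- --npy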
@@ -419,7 +409,7 @@ struct Args {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    use rand::prelude::*;
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
@@ -431,6 +421,7 @@ async fn main() -> Result<()> {
     let api = Api::new()?;
     let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
     let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Filename {tokenizer_filename:?}");
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
@@ -446,30 +437,25 @@ async fn main() -> Result<()> {
         let filename = api.get(&repo, rfilename).await?;
         filenames.push(filename);
     }
-    // let weight_path = std::path::Path::new("llama.npz");
-    // let weights = if weight_path.exists() {
-    // println!("loading weights from {weight_path:?}");
-    // let start_load = std::time::Instant::now();
-    // let tensors = Tensor::read_npz(weight_path)?;
-    // println!("loaded weights in {:?}", start_load.elapsed());
-    // let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-    // Some(tensors)
-    // } else {
-    // println!("cannot find {weight_path:?}, using zero weights");
-    // None
-    // };
-    // let vb = VarBuilder::new::<f32>(&device, weights);
 
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::load(&device, &filenames, &cache, &config)?;
+    let start = std::time::Instant::now();
+    let llama = if args.npy {
+        println!("building the model (NPY)");
+        Llama::load_npy(&device, &filenames, &cache, &config)?
+    } else {
+        println!("building the model (SF)");
+        Llama::load(&device, &filenames, &cache, &config)?
+    };
+    println!("Loaded in {:?}", start.elapsed());
 
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    //let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
@@ -477,8 +463,20 @@ async fn main() -> Result<()> {
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
-        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-        let next_token = distr.sample(&mut rng) as u32;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                if &val_max > val {
+                    (idx_max, val_max)
+                } else {
+                    (idx, *val)
+                }
+            })
+            .0 as u32;
+        // let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+
+        // let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!(
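This hunk is the heart of the commit: `WeightedIndex` sampling (and with it the `rng`) is commented out in favor of a greedy argmax over the softmaxed logits, so repeated runs emit identical tokens and outputs can be compared across implementations. The manual `fold` is used because `f32` is only `PartialOrd`, which rules out `Iterator::max`; a standalone sketch of the same reduction (hypothetical helper mirroring the diff, ties going to the later index):

    // Greedy argmax over f32 probabilities; f32 is not Ord, so the running
    // maximum is tracked by hand. Like the diff, it seeds with element 0
    // and panics on an empty slice.
    fn argmax(prs: &[f32]) -> u32 {
        prs.iter()
            .enumerate()
            .fold((0, prs[0]), |(idx_max, val_max), (idx, &val)| {
                if val_max > val {
                    (idx_max, val_max)
                } else {
                    (idx, val)
                }
            })
            .0 as u32
    }

    fn main() {
        assert_eq!(argmax(&[0.1, 0.7, 0.2]), 1);
    }

Note that under greedy selection the temperature no longer influences which token wins: dividing the logits by a positive temperature and applying softmax are both order-preserving.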
@@ -1,5 +1,7 @@
+use super::*;
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
 use std::collections::HashMap;
+use std::path::PathBuf;
 use std::sync::Arc;
 
 #[allow(dead_code)]
@@ -40,22 +42,15 @@ impl VarBuilder {
         self.vars.borrow().len()
     }
 
-    pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn var(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let mut vars = self.vars.borrow_mut();
         let parameter = match self.tensors.as_ref() {
-            None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
+            None => panic!("Cannot find tensors"),
             Some(tensors) => match tensors.get(&path) {
                 Some(tensor) => tensor.to_device(&self.default_device)?,
                 None => panic!("cannot find tensor for {path}"),
             },
         };
-        vars.push(NamedVar {
-            path,
-            dtype: self.default_dtype,
-            shape,
-        });
         Ok(parameter)
     }
 
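The reworked `var` drops the shape parameter, the `Tensor::zeros` fallback, and the `NamedVar` bookkeeping: every lookup must now resolve to a real tensor, so a missing weight fails loudly instead of being silently compared against zeros. A reduced model of the new lookup semantics, with plain types standing in for candle's (illustrative only):

    use std::collections::HashMap;

    // Reduced model: dot-joined path lookup with no zero fallback.
    fn var(tensors: Option<&HashMap<String, Vec<f32>>>, prefix: &[&str], s: &str) -> Vec<f32> {
        let path = format!("{}.{s}", prefix.join("."));
        match tensors {
            None => panic!("Cannot find tensors"),
            Some(map) => match map.get(&path) {
                Some(t) => t.clone(),
                None => panic!("cannot find tensor for {path}"),
            },
        }
    }

    fn main() {
        let mut map = HashMap::new();
        map.insert("transformer.wte.weight".to_string(), vec![0.5_f32; 4]);
        assert_eq!(var(Some(&map), &["transformer", "wte"], "weight").len(), 4);
    }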
@@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
         &self / rhs
     }
 }
+
+impl Embedding {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let embeddings = vb.var("weight")?;
+        Ok(Self { embeddings })
+    }
+}
+
+impl Linear {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let weight = vb.var("weight")?.t()?;
+        Ok(Self { weight })
+    }
+}
+
+impl RmsNorm {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let scale = vb.var("scale")?;
+        Ok(Self::new(scale))
+    }
+}
+
+impl CausalSelfAttention {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let c_attn = Linear::load_npy(&vb / "c_attn")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_attn, c_proj, config.n_head, cache))
+    }
+}
+
+impl Mlp {
+    fn load_npy(vb: VarBuilder) -> Result<Self> {
+        let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
+        let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
+        let c_proj = Linear::load_npy(&vb / "c_proj")?;
+        Ok(Self::new(c_fc1, c_fc2, c_proj))
+    }
+}
+
+impl Block {
+    fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
+        let mlp = Mlp::load_npy(&vb / "mlp")?;
+        let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
+        let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
+    }
+}
+
+impl Llama {
+    pub fn load_npy(
+        device: &Device,
+        _filenames: &[PathBuf],
+        cache: &Cache,
+        config: &Config,
+    ) -> Result<Self> {
+        let weight_path = std::path::Path::new("/data/llama.npz");
+        let weights = if weight_path.exists() {
+            println!("loading weights from {weight_path:?}");
+            let start_load = std::time::Instant::now();
+            let tensors = Tensor::read_npz(weight_path)?;
+            println!("loaded weights in {:?}", start_load.elapsed());
+            let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+            Some(tensors)
+        } else {
+            println!("cannot find {weight_path:?}, using zero weights");
+            None
+        };
+        let vb = VarBuilder::new::<f32>(device, weights);
+
+        let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
+        let lm_head = Linear::load_npy(&vb / "lm_head")?;
+        let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
+        let blocks: Vec<_> = (0..config.n_layer)
+            .map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
+            .collect();
+
+        Ok(Self::new(wte, blocks, norm, lm_head))
+    }
+}
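All of the new `load_npy` constructors lean on the `Div` overload whose tail is visible at the top of this hunk: each `/ segment` extends the builder's path, and `var` joins the segments with dots to form npz keys such as `transformer.h.0.attn.c_attn.weight`. A simplified standalone model of that mechanism (not candle's actual types):

    // Simplified model of VarBuilder's `/` path building; illustrative only.
    #[derive(Clone)]
    struct PathBuilder {
        path: Vec<String>,
    }

    impl<S: ToString> std::ops::Div<S> for &PathBuilder {
        type Output = PathBuilder;
        fn div(self, rhs: S) -> PathBuilder {
            let mut path = self.path.clone();
            path.push(rhs.to_string());
            PathBuilder { path }
        }
    }

    impl PathBuilder {
        // Mirrors `format!("{}.{s}", self.path.join("."))` from the diff.
        fn key(&self, s: &str) -> String {
            format!("{}.{s}", self.path.join("."))
        }
    }

    fn main() {
        let vb = PathBuilder { path: vec!["transformer".to_string()] };
        let h0 = &(&(&vb / "h") / 0) / "attn";
        assert_eq!(h0.key("weight"), "transformer.h.0.attn.weight");
    }

One wrinkle worth noting: the missing-file branch of `load_npy` still prints "using zero weights" and passes `None` to `VarBuilder::new`, but with the reworked `var` a `None` tensor map now panics at the first lookup, so that message is stale.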