Final updates -> moving to deterministic for easier comparison.

This commit is contained in:
Ubuntu
2023-06-28 12:27:03 +00:00
parent 926fffa0b7
commit ece3ec6167
2 changed files with 120 additions and 42 deletions

View File

@ -19,9 +19,7 @@ use candle_hub::{api::Api, Repo, RepoType};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
// mod var_store; mod var_store;
// use var_store::VarBuilder;
mod weights; mod weights;
const CONTEXT_SIZE: usize = 512; const CONTEXT_SIZE: usize = 512;
@ -196,11 +194,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
impl Mlp { impl Mlp {
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self { fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
// let n_hidden = 8 * n_embd / 3;
// let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
// let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
// let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
Self { Self {
c_fc1, c_fc1,
c_fc2, c_fc2,
@ -244,10 +237,6 @@ impl Cache {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i))) .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.collect(); .collect();
// Once lower_triangle is available, use the followig:
//let mask = Tensor::new(1u32, &device)?
// .broadcast_as(&[t, t])?
// .lower_triangle()?
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone()); masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
@ -265,13 +254,10 @@ struct CausalSelfAttention {
impl CausalSelfAttention { impl CausalSelfAttention {
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self { fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
// let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
Self { Self {
c_attn, c_attn,
c_proj, c_proj,
n_head, n_head,
// n_embd,
cache: cache.clone(), cache: cache.clone(),
} }
} }
@ -408,6 +394,10 @@ struct Args {
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// Use npy instead of safetensors
#[arg(long)]
npy: bool,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)] #[arg(long, default_value_t = 1.0)]
temperature: f64, temperature: f64,
@ -419,7 +409,7 @@ struct Args {
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
use rand::prelude::*; //use rand::prelude::*;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
let args = Args::parse(); let args = Args::parse();
@ -431,6 +421,7 @@ async fn main() -> Result<()> {
let api = Api::new()?; let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
println!("Filename {tokenizer_filename:?}");
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mut tokens = tokenizer let mut tokens = tokenizer
.encode(START_PROMPT, true) .encode(START_PROMPT, true)
@ -446,30 +437,25 @@ async fn main() -> Result<()> {
let filename = api.get(&repo, rfilename).await?; let filename = api.get(&repo, rfilename).await?;
filenames.push(filename); filenames.push(filename);
} }
// let weight_path = std::path::Path::new("llama.npz");
// let weights = if weight_path.exists() {
// println!("loading weights from {weight_path:?}");
// let start_load = std::time::Instant::now();
// let tensors = Tensor::read_npz(weight_path)?;
// println!("loaded weights in {:?}", start_load.elapsed());
// let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
// Some(tensors)
// } else {
// println!("cannot find {weight_path:?}, using zero weights");
// None
// };
// let vb = VarBuilder::new::<f32>(&device, weights);
println!("building the model"); println!("building the model");
let config = Config::config_7b(); let config = Config::config_7b();
let cache = Cache::new(&device); let cache = Cache::new(&device);
let llama = Llama::load(&device, &filenames, &cache, &config)?; let start = std::time::Instant::now();
let llama = if args.npy {
println!("building the model (NPY)");
Llama::load_npy(&device, &filenames, &cache, &config)?
} else {
println!("building the model (SF)");
Llama::load(&device, &filenames, &cache, &config)?
};
println!("Loaded in {:?}", start.elapsed());
println!("pre-computing the positional embeddings"); println!("pre-computing the positional embeddings");
let freqs_cis = precompute_freqs_cis(&config, &device)?; let freqs_cis = precompute_freqs_cis(&config, &device)?;
println!("starting the inference loop"); println!("starting the inference loop");
let mut new_tokens = vec![]; let mut new_tokens = vec![];
let mut rng = thread_rng(); //let mut rng = thread_rng();
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0..args.sample_len { for index in 0..args.sample_len {
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
@ -477,8 +463,20 @@ async fn main() -> Result<()> {
let logits = llama.forward(&input, &freqs_cis)?; let logits = llama.forward(&input, &freqs_cis)?;
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?; let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?; let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?; let next_token = logits_v
let next_token = distr.sample(&mut rng) as u32; .iter()
.enumerate()
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
if &val_max > val {
(idx_max, val_max)
} else {
(idx, *val)
}
})
.0 as u32;
// let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
// let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); new_tokens.push(next_token);
println!( println!(

View File

@ -1,5 +1,7 @@
use super::*;
use candle::{DType, Device, Result, Shape, Tensor, WithDType}; use candle::{DType, Device, Result, Shape, Tensor, WithDType};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
#[allow(dead_code)] #[allow(dead_code)]
@ -40,22 +42,15 @@ impl VarBuilder {
self.vars.borrow().len() self.vars.borrow().len()
} }
pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> { pub fn var(&self, s: &str) -> Result<Tensor> {
let shape = shape.into();
let path = format!("{}.{s}", self.path.join(".")); let path = format!("{}.{s}", self.path.join("."));
let mut vars = self.vars.borrow_mut();
let parameter = match self.tensors.as_ref() { let parameter = match self.tensors.as_ref() {
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?, None => panic!("Cannot find tensors"),
Some(tensors) => match tensors.get(&path) { Some(tensors) => match tensors.get(&path) {
Some(tensor) => tensor.to_device(&self.default_device)?, Some(tensor) => tensor.to_device(&self.default_device)?,
None => panic!("cannot find tensor for {path}"), None => panic!("cannot find tensor for {path}"),
}, },
}; };
vars.push(NamedVar {
path,
dtype: self.default_dtype,
shape,
});
Ok(parameter) Ok(parameter)
} }
@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
&self / rhs &self / rhs
} }
} }
impl Embedding {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let embeddings = vb.var("weight")?;
Ok(Self { embeddings })
}
}
impl Linear {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let weight = vb.var("weight")?.t()?;
Ok(Self { weight })
}
}
impl RmsNorm {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let scale = vb.var("scale")?;
Ok(Self::new(scale))
}
}
impl CausalSelfAttention {
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let c_attn = Linear::load_npy(&vb / "c_attn")?;
let c_proj = Linear::load_npy(&vb / "c_proj")?;
Ok(Self::new(c_attn, c_proj, config.n_head, cache))
}
}
impl Mlp {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
let c_proj = Linear::load_npy(&vb / "c_proj")?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
impl Block {
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
let mlp = Mlp::load_npy(&vb / "mlp")?;
let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
impl Llama {
pub fn load_npy(
device: &Device,
_filenames: &[PathBuf],
cache: &Cache,
config: &Config,
) -> Result<Self> {
let weight_path = std::path::Path::new("/data/llama.npz");
let weights = if weight_path.exists() {
println!("loading weights from {weight_path:?}");
let start_load = std::time::Instant::now();
let tensors = Tensor::read_npz(weight_path)?;
println!("loaded weights in {:?}", start_load.elapsed());
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
Some(tensors)
} else {
println!("cannot find {weight_path:?}, using zero weights");
None
};
let vb = VarBuilder::new::<f32>(device, weights);
let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
let lm_head = Linear::load_npy(&vb / "lm_head")?;
let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
let blocks: Vec<_> = (0..config.n_layer)
.map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
}
}