Remove unwrap.

Nicolas Patry
2023-06-29 12:04:25 +00:00
parent 2fe1d3e36d
commit e63ed6aaa3


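The commit applies one pattern throughout the llama example: inside functions that already return a Result, each .unwrap() on an intermediate Result is replaced by the ? operator, so errors propagate to the caller instead of panicking. A minimal sketch of the pattern (the parse_port helper is hypothetical and not part of this diff):

use anyhow::Result;

// Sketch only, not from this repository: .unwrap() vs ? inside a
// function that returns anyhow::Result.
fn parse_port(s: &str) -> Result<u16> {
    // Before: a malformed string panics the whole program.
    // let port = s.parse::<u16>().unwrap();
    // After: the ParseIntError is converted into anyhow::Error and returned.
    let port = s.parse::<u16>()?;
    Ok(port)
}

The conversion works because ? calls From on the error type of the enclosing function, which is why the tensor calls below can drop their .unwrap()s once the surrounding functions return Result.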
@@ -13,7 +13,7 @@
 // transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distributions::Distribution, thread_rng};
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};
@@ -138,7 +138,7 @@ impl Embedding {
     }

     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings).unwrap())
+        Ok(Tensor::embedding(indexes, &self.embeddings)?)
     }
 }
@@ -152,7 +152,7 @@ impl Linear {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t().unwrap()).unwrap();
+        let x = x.matmul(&self.weight.t()?)?;
         Ok(x)
     }
 }
@@ -168,18 +168,16 @@ impl RmsNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2().unwrap();
-        let norm_x = ((&x * &x).unwrap().sum(&[1]).unwrap() / hidden_size as f64).unwrap();
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size)).unwrap();
-        let x_normed = (x / (norm_x + 1e-5).unwrap().sqrt().unwrap()).unwrap();
-        let size = self.scale.shape().r1().unwrap();
+        let (seq_len, hidden_size) = x.shape().r2()?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
         let scale = self
             .scale
-            .to_dtype(DType::F32)
-            .unwrap()
-            .broadcast_as((seq_len, size))
-            .unwrap();
-        let x = (scale * x_normed).unwrap();
+            .to_dtype(DType::F32)?
+            .broadcast_as((seq_len, size))?;
+        let x = (scale * x_normed)?;
         let x = x.to_dtype(DType::F16)?;
         Ok(x)
     }
@@ -192,7 +190,7 @@ struct Mlp {
 }

 fn silu(xs: &Tensor) -> Result<Tensor> {
-    Ok((xs / (xs.neg().unwrap().exp().unwrap() + 1.0).unwrap()).unwrap())
+    Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
 }

 impl Mlp {
@@ -205,19 +203,15 @@ impl Mlp {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x).unwrap()).unwrap() * self.c_fc2.forward(x).unwrap())
-            .unwrap();
+        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }
 }

 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())
-        .unwrap()
-        .broadcast_as(shape.dims())
-        .unwrap();
-    let m = mask.where_cond(&on_true, on_false).unwrap();
+    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
@@ -244,7 +238,7 @@ impl Cache {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device).unwrap();
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
             Ok(mask)
         }
@@ -274,70 +268,47 @@ impl CausalSelfAttention {
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
-        let x = x.reshape(dims).unwrap();
+        let x = x.reshape(dims)?;
         let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1).unwrap();
-        let im_x = x.narrow(rank - 1, 1, 1).unwrap();
+        let re_x = x.narrow(rank - 1, 0, 1)?;
+        let im_x = x.narrow(rank - 1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)
-            .unwrap()
-            .broadcast_as(re_x.shape())
-            .unwrap();
+            .narrow(rank - 1, 0, 1)?
+            .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)
-            .unwrap()
-            .broadcast_as(im_x.shape())
-            .unwrap();
-        let re = ((&re_x * &re_f).unwrap() - (&im_x * &im_f).unwrap()).unwrap();
-        let im = ((&re_x * &im_f).unwrap() + (&im_x * &re_f).unwrap()).unwrap();
-        let rope = Tensor::cat(&[&re, &im], rank - 1).unwrap();
-        let rope = rope.flatten(Some(rope.rank() - 2), None).unwrap();
+            .narrow(rank - 1, 1, 1)?
+            .broadcast_as(im_x.shape())?;
+        let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
+        let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
+        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
+        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
         Ok(rope)
     }

     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (t, c) = x.shape().r2().unwrap();
-        let qkv = self.c_attn.forward(x).unwrap();
-        let qkv = qkv.to_dtype(DType::F32).unwrap();
+        let (t, c) = x.shape().r2()?;
+        let qkv = self.c_attn.forward(x)?;
+        let qkv = qkv.to_dtype(DType::F32)?;
         let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd).unwrap();
-        let k = qkv.narrow(1, n_embd, n_embd).unwrap();
-        let v = qkv.narrow(1, 2 * n_embd, n_embd).unwrap();
+        let q = qkv.narrow(1, 0, n_embd)?;
+        let k = qkv.narrow(1, n_embd, n_embd)?;
+        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
         let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let q = q
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let v = v
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let q = self.apply_rotary_emb(&q, freqs_cis).unwrap();
-        let k = self.apply_rotary_emb(&k, freqs_cis).unwrap();
+        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = self.apply_rotary_emb(&q, freqs_cis)?;
+        let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
-        let att = (q.matmul(&k.t().unwrap()).unwrap()
-            / (*k_shape.dims().last().unwrap() as f64).sqrt())
-        .unwrap();
-        let mask = self
-            .cache
-            .mask(t)
-            .unwrap()
-            .broadcast_as(att.shape())
-            .unwrap();
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY).unwrap();
-        let att = att.softmax(att.rank() - 1).unwrap();
+        let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous().unwrap()).unwrap();
-        let y = y.transpose(0, 1).unwrap().reshape(&[t, c]).unwrap();
-        let y = y.to_dtype(DType::F16).unwrap();
-        let y = self.c_proj.forward(&y).unwrap();
+        let y = att.matmul(&v.contiguous()?)?;
+        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.to_dtype(DType::F16)?;
+        let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
 }
@@ -360,13 +331,8 @@ impl Block {
     }

     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let x = (self
-            .attn
-            .forward(&self.rms_1.forward(x).unwrap(), freqs_cis)
-            .unwrap()
-            + x)
-        .unwrap();
-        let x = (self.mlp.forward(&self.rms_2.forward(&x).unwrap()).unwrap() + x).unwrap();
+        let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
+        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
         Ok(x)
     }
 }
@@ -390,18 +356,18 @@ impl Llama {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1().unwrap();
-        let mut x = self.wte.forward(x).unwrap();
+        let t = x.shape().r1()?;
+        let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, freqs_cis).unwrap();
+            x = block.forward(&x, freqs_cis)?;
         }
-        let x = self.ln_f.forward(&x).unwrap();
-        let x = x.narrow(0, t - 1, 1).unwrap();
-        let logits = self.lm_head.forward(&x).unwrap();
+        let x = self.ln_f.forward(&x)?;
+        let x = x.narrow(0, t - 1, 1)?;
+        let logits = self.lm_head.forward(&x)?;
         let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2().unwrap();
+        let (b, vocab_size) = logits.shape().r2()?;
         assert_eq!(b, 1);
-        Ok(logits.reshape(vocab_size).unwrap())
+        Ok(logits.reshape(vocab_size)?)
     }
 }
@@ -413,18 +379,16 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
     let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
-    let theta = Tensor::new(theta.as_slice(), device).unwrap();
-    let arange = Tensor::new(arange.as_slice(), device).unwrap();
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let arange = Tensor::new(arange.as_slice(), device)?;
     let idx_theta = arange
-        .reshape((arange.elem_count(), 1))
-        .unwrap()
-        .matmul(&theta.reshape((1, theta.elem_count())).unwrap())
-        .unwrap();
+        .reshape((arange.elem_count(), 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let shape = [1, seq_len, n_elem / 2, 1];
-    let idx_theta_cos = idx_theta.cos().unwrap().reshape(&shape).unwrap();
-    let idx_theta_sin = idx_theta.sin().unwrap().reshape(&shape).unwrap();
+    let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
+    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
     let last_dim = idx_theta_cos.rank() - 1;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim).unwrap())
+    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
 }

 #[derive(Parser, Debug)]
@@ -442,10 +406,6 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,

-    /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
-
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
@@ -453,26 +413,28 @@
 #[tokio::main]
 async fn main() -> Result<()> {
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
     } else {
-        Device::new_cuda(0).unwrap()
+        Device::new_cuda(0)?
     };
-    let api = Api::new()?;
-    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
-    println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = if args.npy {
         println!("building the model (NPY)");
         (
-            Llama::load_npy(&device, "/data/llama.npz", &cache, &config).unwrap(),
+            Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
             std::path::Path::new("llama-tokenizer.json").to_path_buf(),
         )
     } else {
+        let api = Api::new()?;
+        let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
         for rfilename in [
@@ -485,51 +447,50 @@ async fn main() -> Result<()> {
         println!("building the model (SF)");
         (
-            Llama::load(&device, &filenames, &cache, &config).unwrap(),
+            Llama::load(&device, &filenames, &cache, &config)?,
             tokenizer_filename,
         )
     };
     println!("Loaded in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename)
-        .map_err(E::msg)
-        .unwrap();
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
-        .map_err(E::msg)
-        .unwrap()
+        .map_err(E::msg)?
         .get_ids()
         .to_vec();
     println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config, &device).unwrap();
+    let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
+    let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &device).unwrap();
-        let logits = llama.forward(&input, &freqs_cis).unwrap();
+        let input = Tensor::new(ctxt, &device)?;
+        let logits = llama.forward(&input, &freqs_cis)?;
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)
-                .unwrap()
-                .softmax(logits.rank() - 1)
-                .unwrap();
-            let logits_v: Vec<f32> = prs.to_vec1().unwrap();
-            let distr = rand::distributions::WeightedIndex::new(&logits_v).unwrap();
+            let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+            let logits_v: Vec<f32> = prs.to_vec1()?;
+            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
             distr.sample(&mut rng) as u32
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1().unwrap();
+            let logits_v: Vec<f32> = logits.to_vec1()?;
             logits_v
                 .iter()
                 .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+                .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                    if &val_max > val {
+                        (idx_max, val_max)
+                    } else {
+                        (idx, *val)
+                    }
+                })
+                .0 as u32
         };
         tokens.push(next_token);
         new_tokens.push(next_token);
@@ -538,10 +499,7 @@ async fn main() -> Result<()> {
             "{} token: {} '{}'",
             index + 1,
             next_token,
-            tokenizer
-                .decode(vec![next_token], true)
-                .map_err(E::msg)
-                .unwrap()
+            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
     let dt = start_gen.elapsed();
@@ -549,7 +507,7 @@ async fn main() -> Result<()> {
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg).unwrap()
+        tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())
 }