diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 83a1d69a..3a025683 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -13,7 +13,7 @@
 // transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distributions::Distribution, thread_rng};
 
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};
@@ -138,7 +138,7 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings).unwrap())
+        Ok(Tensor::embedding(indexes, &self.embeddings)?)
     }
 }
 
@@ -152,7 +152,7 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t().unwrap()).unwrap();
+        let x = x.matmul(&self.weight.t()?)?;
         Ok(x)
     }
 }
@@ -168,18 +168,16 @@ impl RmsNorm {
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2().unwrap();
-        let norm_x = ((&x * &x).unwrap().sum(&[1]).unwrap() / hidden_size as f64).unwrap();
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size)).unwrap();
-        let x_normed = (x / (norm_x + 1e-5).unwrap().sqrt().unwrap()).unwrap();
-        let size = self.scale.shape().r1().unwrap();
+        let (seq_len, hidden_size) = x.shape().r2()?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
         let scale = self
             .scale
-            .to_dtype(DType::F32)
-            .unwrap()
-            .broadcast_as((seq_len, size))
-            .unwrap();
-        let x = (scale * x_normed).unwrap();
+            .to_dtype(DType::F32)?
+            .broadcast_as((seq_len, size))?;
+        let x = (scale * x_normed)?;
         let x = x.to_dtype(DType::F16)?;
         Ok(x)
     }
@@ -192,7 +190,7 @@ struct Mlp {
 }
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
-    Ok((xs / (xs.neg().unwrap().exp().unwrap() + 1.0).unwrap()).unwrap())
+    Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
 }
 
 impl Mlp {
@@ -205,19 +203,15 @@ impl Mlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x).unwrap()).unwrap() * self.c_fc2.forward(x).unwrap())
-            .unwrap();
+        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
     }
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())
-        .unwrap()
-        .broadcast_as(shape.dims())
-        .unwrap();
-    let m = mask.where_cond(&on_true, on_false).unwrap();
+    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
 
@@ -244,7 +238,7 @@ impl Cache {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device).unwrap();
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
             Ok(mask)
         }
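Note on the hunk above: Cache::mask builds a 0/1 matrix whose entry (i, j) is 1 exactly when j > i, i.e. when position j lies in the future of position i; masked_fill later replaces those ones with f32::NEG_INFINITY before the softmax. A minimal standalone sketch of the same construction (plain Rust, no candle types; t = 3 is just an illustrative size):

    // Same 0/1 causal mask as Cache::mask, flattened row-major.
    fn causal_mask(t: usize) -> Vec<u32> {
        (0..t)
            .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
            .collect()
    }

    fn main() {
        // Prints [0, 1, 1, 0, 0, 1, 0, 0, 0]: row i masks every column j > i.
        println!("{:?}", causal_mask(3));
    }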
@@ -274,70 +268,47 @@ impl CausalSelfAttention {
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
-        let x = x.reshape(dims).unwrap();
+        let x = x.reshape(dims)?;
         let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1).unwrap();
-        let im_x = x.narrow(rank - 1, 1, 1).unwrap();
+        let re_x = x.narrow(rank - 1, 0, 1)?;
+        let im_x = x.narrow(rank - 1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)
-            .unwrap()
-            .broadcast_as(re_x.shape())
-            .unwrap();
+            .narrow(rank - 1, 0, 1)?
+            .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)
-            .unwrap()
-            .broadcast_as(im_x.shape())
-            .unwrap();
-        let re = ((&re_x * &re_f).unwrap() - (&im_x * &im_f).unwrap()).unwrap();
-        let im = ((&re_x * &im_f).unwrap() + (&im_x * &re_f).unwrap()).unwrap();
-        let rope = Tensor::cat(&[&re, &im], rank - 1).unwrap();
-        let rope = rope.flatten(Some(rope.rank() - 2), None).unwrap();
+            .narrow(rank - 1, 1, 1)?
+            .broadcast_as(im_x.shape())?;
+        let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
+        let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
+        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
+        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
         Ok(rope)
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (t, c) = x.shape().r2().unwrap();
-        let qkv = self.c_attn.forward(x).unwrap();
-        let qkv = qkv.to_dtype(DType::F32).unwrap();
+        let (t, c) = x.shape().r2()?;
+        let qkv = self.c_attn.forward(x)?;
+        let qkv = qkv.to_dtype(DType::F32)?;
         let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd).unwrap();
-        let k = qkv.narrow(1, n_embd, n_embd).unwrap();
-        let v = qkv.narrow(1, 2 * n_embd, n_embd).unwrap();
+        let q = qkv.narrow(1, 0, n_embd)?;
+        let k = qkv.narrow(1, n_embd, n_embd)?;
+        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
         let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let q = q
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let v = v
-            .reshape(target_dim.as_slice())
-            .unwrap()
-            .transpose(0, 1)
-            .unwrap();
-        let q = self.apply_rotary_emb(&q, freqs_cis).unwrap();
-        let k = self.apply_rotary_emb(&k, freqs_cis).unwrap();
+        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = self.apply_rotary_emb(&q, freqs_cis)?;
+        let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
-        let att = (q.matmul(&k.t().unwrap()).unwrap()
-            / (*k_shape.dims().last().unwrap() as f64).sqrt())
-        .unwrap();
-        let mask = self
-            .cache
-            .mask(t)
-            .unwrap()
-            .broadcast_as(att.shape())
-            .unwrap();
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY).unwrap();
-        let att = att.softmax(att.rank() - 1).unwrap();
+        let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
+        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous().unwrap()).unwrap();
-        let y = y.transpose(0, 1).unwrap().reshape(&[t, c]).unwrap();
-        let y = y.to_dtype(DType::F16).unwrap();
-        let y = self.c_proj.forward(&y).unwrap();
+        let y = att.matmul(&v.contiguous()?)?;
+        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.to_dtype(DType::F16)?;
+        let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
 }
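Note on apply_rotary_emb above: the re/im lines are complex multiplication written out, rotating each (re_x, im_x) pair of the query/key by the (cos, sin) pair stored in freqs_cis (precompute_freqs_cis below concatenates idx_theta_cos and idx_theta_sin along the last dimension, which is what the two narrow calls pick apart). A scalar sketch of the same identity, as a hypothetical helper rather than anything in the patch:

    // (re + i*im) * (cos_t + i*sin_t): rotate one pair by the angle theta.
    fn rotate(re: f32, im: f32, cos_t: f32, sin_t: f32) -> (f32, f32) {
        (re * cos_t - im * sin_t, re * sin_t + im * cos_t)
    }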
@@ -360,13 +331,8 @@ impl Block {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let x = (self
-            .attn
-            .forward(&self.rms_1.forward(x).unwrap(), freqs_cis)
-            .unwrap()
-            + x)
-            .unwrap();
-        let x = (self.mlp.forward(&self.rms_2.forward(&x).unwrap()).unwrap() + x).unwrap();
+        let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
+        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
         Ok(x)
     }
 }
@@ -390,18 +356,18 @@ impl Llama {
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1().unwrap();
-        let mut x = self.wte.forward(x).unwrap();
+        let t = x.shape().r1()?;
+        let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, freqs_cis).unwrap();
+            x = block.forward(&x, freqs_cis)?;
         }
-        let x = self.ln_f.forward(&x).unwrap();
-        let x = x.narrow(0, t - 1, 1).unwrap();
-        let logits = self.lm_head.forward(&x).unwrap();
+        let x = self.ln_f.forward(&x)?;
+        let x = x.narrow(0, t - 1, 1)?;
+        let logits = self.lm_head.forward(&x)?;
         let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2().unwrap();
+        let (b, vocab_size) = logits.shape().r2()?;
         assert_eq!(b, 1);
-        Ok(logits.reshape(vocab_size).unwrap())
+        Ok(logits.reshape(vocab_size)?)
     }
 }
 
@@ -413,18 +379,16 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
     let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
-    let theta = Tensor::new(theta.as_slice(), device).unwrap();
-    let arange = Tensor::new(arange.as_slice(), device).unwrap();
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let arange = Tensor::new(arange.as_slice(), device)?;
     let idx_theta = arange
-        .reshape((arange.elem_count(), 1))
-        .unwrap()
-        .matmul(&theta.reshape((1, theta.elem_count())).unwrap())
-        .unwrap();
+        .reshape((arange.elem_count(), 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let shape = [1, seq_len, n_elem / 2, 1];
-    let idx_theta_cos = idx_theta.cos().unwrap().reshape(&shape).unwrap();
-    let idx_theta_sin = idx_theta.sin().unwrap().reshape(&shape).unwrap();
+    let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
+    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
     let last_dim = idx_theta_cos.rank() - 1;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim).unwrap())
+    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
 }
 
 #[derive(Parser, Debug)]
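Note on precompute_freqs_cis above: the (arange.elem_count(), 1) x (1, theta.elem_count()) matmul is simply the outer product of positions and inverse frequencies. Assuming theta is built over (0..n_elem).step_by(2) (that line sits just above the hunk and is not shown here), an equivalent loop-based sketch of the resulting table is:

    // idx_theta[p][k] = p / 10000^(2k / n_elem); hypothetical helper, for illustration only.
    fn idx_theta(seq_len: usize, n_elem: usize) -> Vec<Vec<f32>> {
        (0..seq_len)
            .map(|p| {
                (0..n_elem / 2)
                    .map(|k| p as f32 / 10000f32.powf(2.0 * k as f32 / n_elem as f32))
                    .collect()
            })
            .collect()
    }

The cos and sin of this table are exactly what apply_rotary_emb consumes.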
@@ -442,10 +406,6 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
-    /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
-    seed: u64,
-
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
@@ -453,26 +413,28 @@ struct Args {
 
 #[tokio::main]
 async fn main() -> Result<()> {
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
     } else {
-        Device::new_cuda(0).unwrap()
+        Device::new_cuda(0)?
     };
+    let api = Api::new()?;
+    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+
     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = if args.npy {
         println!("building the model (NPY)");
         (
-            Llama::load_npy(&device, "/data/llama.npz", &cache, &config).unwrap(),
+            Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
             std::path::Path::new("llama-tokenizer.json").to_path_buf(),
         )
     } else {
-        let api = Api::new()?;
-        let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
         for rfilename in [
@@ -485,51 +447,50 @@ async fn main() -> Result<()> {
 
         println!("building the model (SF)");
         (
-            Llama::load(&device, &filenames, &cache, &config).unwrap(),
+            Llama::load(&device, &filenames, &cache, &config)?,
             tokenizer_filename,
         )
     };
     println!("Loaded in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename)
-        .map_err(E::msg)
-        .unwrap();
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
-        .map_err(E::msg)
-        .unwrap()
+        .map_err(E::msg)?
         .get_ids()
         .to_vec();
 
     println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config, &device).unwrap();
+    let freqs_cis = precompute_freqs_cis(&config, &device)?;
 
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
+    let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &device).unwrap();
-        let logits = llama.forward(&input, &freqs_cis).unwrap();
+        let input = Tensor::new(ctxt, &device)?;
+        let logits = llama.forward(&input, &freqs_cis)?;
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)
-                .unwrap()
-                .softmax(logits.rank() - 1)
-                .unwrap();
-            let logits_v: Vec<f32> = prs.to_vec1().unwrap();
-            let distr = rand::distributions::WeightedIndex::new(&logits_v).unwrap();
+            let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+            let logits_v: Vec<f32> = prs.to_vec1()?;
+            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
             distr.sample(&mut rng) as u32
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1().unwrap();
+            let logits_v: Vec<f32> = logits.to_vec1()?;
             logits_v
                 .iter()
                 .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+                .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                    if &val_max > val {
+                        (idx_max, val_max)
+                    } else {
+                        (idx, *val)
+                    }
+                })
+                .0 as u32
         };
         tokens.push(next_token);
         new_tokens.push(next_token);
@@ -538,10 +499,7 @@ async fn main() -> Result<()> {
             "{} token: {} '{}'",
             index + 1,
             next_token,
-            tokenizer
-                .decode(vec![next_token], true)
-                .map_err(E::msg)
-                .unwrap()
+            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
         );
     }
     let dt = start_gen.elapsed();
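Note on the sampling branch above: with --temperature set, the loop samples from softmax(logits / T) through rand's WeightedIndex, which accepts unnormalized weights; without it, the greedy fold picks the argmax. A self-contained sketch of the temperature path (hypothetical helper; subtracting the max is the standard stability trick for a hand-rolled softmax and does not change the distribution):

    use rand::distributions::{Distribution, WeightedIndex};

    // Sample a token id from softmax(logits / temperature).
    fn sample_with_temperature(logits: &[f32], temperature: f64, rng: &mut impl rand::Rng) -> u32 {
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max) as f64;
        // Unnormalized softmax weights; WeightedIndex does not need them to sum to 1.
        let weights: Vec<f64> = logits
            .iter()
            .map(|&l| ((l as f64 - max) / temperature).exp())
            .collect();
        let distr = WeightedIndex::new(&weights).unwrap();
        distr.sample(rng) as u32
    }

A higher temperature flattens the weights and makes sampling closer to uniform; as T approaches 0 the behavior approaches the greedy branch.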
@@ -549,7 +507,7 @@ async fn main() -> Result<()> {
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg).unwrap()
+        tokenizer.decode(new_tokens, true).map_err(E::msg)?
     );
     Ok(())
 }
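Closing note: the fold that replaces max_by in the greedy branch is a plain argmax, with ties resolved toward the later index (the same behavior max_by has for equal keys). An equivalent standalone form, hypothetical and only for comparison:

    // Index of the largest logit; on ties the later index wins,
    // mirroring the fold in the patch. Panics on an empty slice,
    // just as logits_v[0] does above.
    fn argmax(logits: &[f32]) -> u32 {
        logits
            .iter()
            .enumerate()
            .fold((0, logits[0]), |(i_max, v_max), (i, &v)| {
                if v_max > v { (i_max, v_max) } else { (i, v) }
            })
            .0 as u32
    }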