mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Moving llama to f16.
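In short, this change keeps the LLaMA weights and activations in f16 and upcasts to f32 only where precision matters: inside RmsNorm, around the attention softmax (the qkv projection is cast to f32 and the attention output back to f16), and for the final logits before sampling. Below is a minimal sketch of that upcast/downcast bracketing, assuming the early candle tensor API that appears in this diff (to_dtype, shape().r2(), sum, broadcast_as); the rms_norm_f16 helper name and the import path are illustrative, not code from the repository.

    // Sketch only: mirrors the f32-upcast pattern that RmsNorm::forward uses in
    // this commit. `rms_norm_f16` is a hypothetical free function, and the
    // `candle` import path assumes the crate layout at the time of this commit.
    use candle::{DType, Result, Tensor};

    fn rms_norm_f16(x_f16: &Tensor, scale_f16: &Tensor) -> Result<Tensor> {
        // Upcast the f16 activations to f32 before computing the norm.
        let x = x_f16.to_dtype(DType::F32)?;
        let (seq_len, hidden_size) = x.shape().r2()?;
        // Mean of squares over the hidden dimension, then divide by the RMS.
        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
        // The learned scale is also stored in f16; upcast it for the multiply.
        let scale = scale_f16
            .to_dtype(DType::F32)?
            .broadcast_as((seq_len, hidden_size))?;
        // Cast back to f16 so the rest of the block keeps running in half precision.
        (scale * x_normed)?.to_dtype(DType::F16)
    }

The same bracketing shows up in the diff below: CausalSelfAttention::forward upcasts qkv to f32 and casts the attention output back to f16, and Llama::forward upcasts the logits to f32 before sampling.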
@@ -1,7 +1,7 @@
 // An implementation of LLaMA https://github.com/facebookresearch/llama
 //
 // This is based on nanoGPT in a similar way to:
 // https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
 //
 // The tokenizer config can be retrieved from:
 // https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
@@ -138,7 +138,7 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings)?)
+        Ok(Tensor::embedding(indexes, &self.embeddings).unwrap())
     }
 }
 
@@ -152,7 +152,7 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        let x = x.matmul(&self.weight.t().unwrap()).unwrap();
         Ok(x)
     }
 }
@@ -167,16 +167,21 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let x = x.to_dtype(DType::F32)?;
+        let (seq_len, hidden_size) = x.shape().r2().unwrap();
+        let norm_x = ((&x * &x).unwrap().sum(&[1]).unwrap() / hidden_size as f64).unwrap();
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size)).unwrap();
+        let x_normed = (x / (norm_x + 1e-5).unwrap().sqrt().unwrap()).unwrap();
+        let size = self.scale.shape().r1().unwrap();
         let scale = self
             .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
-        Ok((scale * x_normed)?)
+            .to_dtype(DType::F32)
+            .unwrap()
+            .broadcast_as((seq_len, size))
+            .unwrap();
+        let x = (scale * x_normed).unwrap();
+        let x = x.to_dtype(DType::F16)?;
+        Ok(x)
     }
 }
 
@@ -187,7 +192,7 @@ struct Mlp {
 }
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
-    Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
+    Ok((xs / (xs.neg().unwrap().exp().unwrap() + 1.0).unwrap()).unwrap())
 }
 
 impl Mlp {
@@ -200,15 +205,19 @@ impl Mlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+        let x = (silu(&self.c_fc1.forward(x).unwrap()).unwrap() * self.c_fc2.forward(x).unwrap())
+            .unwrap();
         self.c_proj.forward(&x)
     }
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
+    let on_true = Tensor::new(on_true, &on_false.device())
+        .unwrap()
+        .broadcast_as(shape.dims())
+        .unwrap();
+    let m = mask.where_cond(&on_true, on_false).unwrap();
     Ok(m)
 }
 
@@ -235,7 +244,7 @@ impl Cache {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+            let mask = Tensor::from_slice(&mask, (t, t), &self.device).unwrap();
            masks.insert(t, mask.clone());
            Ok(mask)
        }
@@ -265,45 +274,70 @@ impl CausalSelfAttention {
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
-        let x = x.reshape(dims)?;
+        let x = x.reshape(dims).unwrap();
         let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1)?;
-        let im_x = x.narrow(rank - 1, 1, 1)?;
+        let re_x = x.narrow(rank - 1, 0, 1).unwrap();
+        let im_x = x.narrow(rank - 1, 1, 1).unwrap();
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)?
-            .broadcast_as(re_x.shape())?;
+            .narrow(rank - 1, 0, 1)
+            .unwrap()
+            .broadcast_as(re_x.shape())
+            .unwrap();
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)?
-            .broadcast_as(im_x.shape())?;
-        let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
-        let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
+            .narrow(rank - 1, 1, 1)
+            .unwrap()
+            .broadcast_as(im_x.shape())
+            .unwrap();
+        let re = ((&re_x * &re_f).unwrap() - (&im_x * &im_f).unwrap()).unwrap();
+        let im = ((&re_x * &im_f).unwrap() + (&im_x * &re_f).unwrap()).unwrap();
+        let rope = Tensor::cat(&[&re, &im], rank - 1).unwrap();
+        let rope = rope.flatten(Some(rope.rank() - 2), None).unwrap();
         Ok(rope)
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (t, c) = x.shape().r2()?;
-        let qkv = self.c_attn.forward(x)?;
+        let (t, c) = x.shape().r2().unwrap();
+        let qkv = self.c_attn.forward(x).unwrap();
+        let qkv = qkv.to_dtype(DType::F32).unwrap();
         let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd)?;
-        let k = qkv.narrow(1, n_embd, n_embd)?;
-        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
+        let q = qkv.narrow(1, 0, n_embd).unwrap();
+        let k = qkv.narrow(1, n_embd, n_embd).unwrap();
+        let v = qkv.narrow(1, 2 * n_embd, n_embd).unwrap();
         let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = self.apply_rotary_emb(&q, freqs_cis)?;
-        let k = self.apply_rotary_emb(&k, freqs_cis)?;
+        let k = k
+            .reshape(target_dim.as_slice())
+            .unwrap()
+            .transpose(0, 1)
+            .unwrap();
+        let q = q
+            .reshape(target_dim.as_slice())
+            .unwrap()
+            .transpose(0, 1)
+            .unwrap();
+        let v = v
+            .reshape(target_dim.as_slice())
+            .unwrap()
+            .transpose(0, 1)
+            .unwrap();
+        let q = self.apply_rotary_emb(&q, freqs_cis).unwrap();
+        let k = self.apply_rotary_emb(&k, freqs_cis).unwrap();
         let k_shape = k.shape();
-        let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
-        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(att.rank() - 1)?;
+        let att = (q.matmul(&k.t().unwrap()).unwrap()
+            / (*k_shape.dims().last().unwrap() as f64).sqrt())
+        .unwrap();
+        let mask = self
+            .cache
+            .mask(t)
+            .unwrap()
+            .broadcast_as(att.shape())
+            .unwrap();
+        let att = masked_fill(&att, &mask, f32::NEG_INFINITY).unwrap();
+        let att = att.softmax(att.rank() - 1).unwrap();
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
-        let y = self.c_proj.forward(&y)?;
+        let y = att.matmul(&v.contiguous().unwrap()).unwrap();
+        let y = y.transpose(0, 1).unwrap().reshape(&[t, c]).unwrap();
+        let y = y.to_dtype(DType::F16).unwrap();
+        let y = self.c_proj.forward(&y).unwrap();
         Ok(y)
     }
 }
@@ -326,8 +360,13 @@ impl Block {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
-        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
+        let x = (self
+            .attn
+            .forward(&self.rms_1.forward(x).unwrap(), freqs_cis)
+            .unwrap()
+            + x)
+            .unwrap();
+        let x = (self.mlp.forward(&self.rms_2.forward(&x).unwrap()).unwrap() + x).unwrap();
         Ok(x)
     }
 }
@@ -351,18 +390,18 @@ impl Llama {
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1()?;
-        let x = self.wte.forward(x)?;
-        let mut x = x.to_dtype(DType::F32)?;
+        let t = x.shape().r1().unwrap();
+        let mut x = self.wte.forward(x).unwrap();
         for block in self.blocks.iter() {
-            x = block.forward(&x, freqs_cis)?;
+            x = block.forward(&x, freqs_cis).unwrap();
         }
-        let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(0, t - 1, 1)?;
-        let logits = self.lm_head.forward(&x)?;
-        let (b, vocab_size) = logits.shape().r2()?;
+        let x = self.ln_f.forward(&x).unwrap();
+        let x = x.narrow(0, t - 1, 1).unwrap();
+        let logits = self.lm_head.forward(&x).unwrap();
+        let logits = logits.to_dtype(DType::F32)?;
+        let (b, vocab_size) = logits.shape().r2().unwrap();
         assert_eq!(b, 1);
-        Ok(logits.reshape(vocab_size)?)
+        Ok(logits.reshape(vocab_size).unwrap())
     }
 }
 
@@ -374,16 +413,18 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
     let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
-    let theta = Tensor::new(theta.as_slice(), device)?;
-    let arange = Tensor::new(arange.as_slice(), device)?;
+    let theta = Tensor::new(theta.as_slice(), device).unwrap();
+    let arange = Tensor::new(arange.as_slice(), device).unwrap();
     let idx_theta = arange
-        .reshape((arange.elem_count(), 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        .reshape((arange.elem_count(), 1))
+        .unwrap()
+        .matmul(&theta.reshape((1, theta.elem_count())).unwrap())
+        .unwrap();
     let shape = [1, seq_len, n_elem / 2, 1];
-    let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
-    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
+    let idx_theta_cos = idx_theta.cos().unwrap().reshape(&shape).unwrap();
+    let idx_theta_sin = idx_theta.sin().unwrap().reshape(&shape).unwrap();
     let last_dim = idx_theta_cos.rank() - 1;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
+    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim).unwrap())
 }
 
 #[derive(Parser, Debug)]
@@ -418,19 +459,19 @@ async fn main() -> Result<()> {
     let device = if args.cpu {
         Device::Cpu
     } else {
-        Device::new_cuda(0)?
+        Device::new_cuda(0).unwrap()
     };
-    let api = Api::new()?;
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = if args.npy {
         println!("building the model (NPY)");
         (
-            Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
+            Llama::load_npy(&device, "/data/llama.npz", &cache, &config).unwrap(),
             std::path::Path::new("llama-tokenizer.json").to_path_buf(),
         )
     } else {
+        let api = Api::new()?;
         let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
@@ -444,20 +485,23 @@ async fn main() -> Result<()> {
 
         println!("building the model (SF)");
         (
-            Llama::load(&device, &filenames, &cache, &config)?,
+            Llama::load(&device, &filenames, &cache, &config).unwrap(),
             tokenizer_filename,
         )
     };
     println!("Loaded in {:?}", start.elapsed());
-    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename)
+        .map_err(E::msg)
+        .unwrap();
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
-        .map_err(E::msg)?
+        .map_err(E::msg)
+        .unwrap()
         .get_ids()
         .to_vec();
 
     println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config, &device)?;
+    let freqs_cis = precompute_freqs_cis(&config, &device).unwrap();
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
@@ -465,18 +509,21 @@ async fn main() -> Result<()> {
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &device)?;
-        let logits = llama.forward(&input, &freqs_cis)?;
+        let input = Tensor::new(ctxt, &device).unwrap();
+        let logits = llama.forward(&input, &freqs_cis).unwrap();
 
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
-            let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
-            let logits_v: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+            let prs = (&logits / temperature)
+                .unwrap()
+                .softmax(logits.rank() - 1)
+                .unwrap();
+            let logits_v: Vec<f32> = prs.to_vec1().unwrap();
+            let distr = rand::distributions::WeightedIndex::new(&logits_v).unwrap();
 
             distr.sample(&mut rng) as u32
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
+            let logits_v: Vec<f32> = logits.to_vec1().unwrap();
             logits_v
                 .iter()
                 .enumerate()
@@ -491,7 +538,10 @@ async fn main() -> Result<()> {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+                tokenizer
+                    .decode(vec![next_token], true)
+                    .map_err(E::msg)
+                    .unwrap()
             );
         }
         let dt = start_gen.elapsed();
@@ -499,7 +549,7 @@ async fn main() -> Result<()> {
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        tokenizer.decode(new_tokens, true).map_err(E::msg).unwrap()
     );
     Ok(())
 }