More scaffolding, now need to implement matmul (for precompute_cos_sin to work).

This commit is contained in:
Nicolas Patry
2023-11-01 16:54:09 +01:00
parent 2d84c16fed
commit 492d164235
4 changed files with 108 additions and 46 deletions

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
@ -181,28 +181,31 @@ pub struct ModelWeights {
span_output: tracing::Span,
}
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
fn precomput_freqs_cis(head_dim: usize, freq_base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let theta = Tensor::new(theta.as_slice(), device)?;
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
let idx_theta = Tensor::new(range.as_slice(), device)?.reshape((MAX_SEQ_LEN, 1))?.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
// .to_dtype(DType::F32)?
// .reshape((MAX_SEQ_LEN, 1))?
// .matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, device: &Device) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@ -276,7 +279,7 @@ impl ModelWeights {
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
@ -331,14 +334,14 @@ impl ModelWeights {
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
@ -346,7 +349,7 @@ impl ModelWeights {
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let mask = self.mask(seq_len, x.device())?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {