Add a slice_set op. (#2193)

* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
This commit is contained in:
Laurent Mazare
2024-05-18 15:58:18 +02:00
committed by GitHub
parent 349c3e806a
commit 01545f7303
6 changed files with 209 additions and 23 deletions

View File

@ -3,9 +3,7 @@ use std::collections::HashMap;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, RmsNorm};
pub const MAX_SEQ_LEN: usize = 4096;
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
#[derive(Debug, Clone)]
struct QLinear {
@ -70,7 +68,7 @@ struct LayerWeights {
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
kv_cache: KvCache,
span_attn: tracing::Span,
span_rot: tracing::Span,
}
@ -122,19 +120,7 @@ impl LayerWeights {
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k.contiguous()?, v.contiguous()?),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k.contiguous()?, v.contiguous()?)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?;
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k.contiguous()?, v.contiguous()?)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
@ -169,6 +155,7 @@ pub struct ModelWeights {
fn precomput_freqs_cis(
head_dim: usize,
max_seq_len: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
@ -177,9 +164,9 @@ fn precomput_freqs_cis(
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.reshape((max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
@ -188,6 +175,7 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
batch_size: usize,
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
@ -202,16 +190,19 @@ impl ModelWeights {
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
let head_dim = embedding_length / head_count;
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
let output = QLinear::new(&ct, reader, "output", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
@ -232,6 +223,12 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let kv_cache = KvCache::new(
2,
(batch_size, head_count_kv, max_seq_len, head_dim),
DType::F32,
device,
)?;
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
@ -240,11 +237,11 @@ impl ModelWeights {
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
head_dim,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
kv_cache,
span_attn,
span_rot,
})