Add a kv-cache to the quantized llama example. (#466)

* Add a kv-cache to the quantized llama example.

* Also print the prompt.

* Bugfix in q6k dequantizing.

* Another bugfix.
This commit is contained in:
Laurent Mazare
2023-08-16 14:28:42 +01:00
committed by GitHub
parent 3071134788
commit a9101700b6
2 changed files with 20 additions and 9 deletions

View File

@ -52,6 +52,7 @@ struct LayerWeights {
head_dim: usize,
cos: Tensor,
sin: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@ -75,7 +76,7 @@ impl LayerWeights {
Ok(rope)
}
fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
@ -94,7 +95,15 @@ impl LayerWeights {
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
// TODO: KV cache.
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
// If we start supporting MQA, we need to repeat the k and v tensors here.
@ -181,6 +190,7 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
})
}
Ok(Self {
@ -209,7 +219,7 @@ impl ModelWeights {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let mut layer_in = self.tok_embeddings.forward(x)?;
for (_layer_idx, layer) in self.layers.iter().enumerate() {
for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
@ -302,8 +312,9 @@ fn main() -> anyhow::Result<()> {
.to_vec();
let mut index_pos = 0;
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
for _index in 0..args.sample_len {
let context_size = tokens.len();
print!("{prompt}");
for index in 0..args.sample_len {
let context_size = if index == 0 { tokens.len() } else { 1 };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;