mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add a kv-cache to the quantized llama example. (#466)
* Add a kv-cache to the quantized llama example. * Also print the prompt. * Bugfix in q6k dequantizing. * Another bugfix.
This commit is contained in:
@ -52,6 +52,7 @@ struct LayerWeights {
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
@ -75,7 +76,7 @@ impl LayerWeights {
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
@ -94,7 +95,15 @@ impl LayerWeights {
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
// TODO: KV cache.
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((k_cache, v_cache)) => {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
// If we start supporting MQA, we need to repeat the k and v tensors here.
|
||||
|
||||
@ -181,6 +190,7 @@ impl ModelWeights {
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
Ok(Self {
|
||||
@ -209,7 +219,7 @@ impl ModelWeights {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mask = self.mask(seq_len)?;
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
for (_layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
for layer in self.layers.iter_mut() {
|
||||
let x = layer_in;
|
||||
let residual = &x;
|
||||
let x = layer.attention_norm.forward(&x)?;
|
||||
@ -302,8 +312,9 @@ fn main() -> anyhow::Result<()> {
|
||||
.to_vec();
|
||||
let mut index_pos = 0;
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
for _index in 0..args.sample_len {
|
||||
let context_size = tokens.len();
|
||||
print!("{prompt}");
|
||||
for index in 0..args.sample_len {
|
||||
let context_size = if index == 0 { tokens.len() } else { 1 };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
|
Reference in New Issue
Block a user