Add a KV cache to falcon. (#104)

This commit is contained in:
Laurent Mazare
2023-07-07 17:24:38 +01:00
committed by GitHub
parent 05ff1cff66
commit e923b3adc2
3 changed files with 80 additions and 43 deletions

View File

@ -340,8 +340,7 @@ impl CausalSelfAttention {
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k_shape = k.shape();
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;