Do not use the kv-cache on external key-value states. (#1054)

Laurent Mazare
2023-10-07 22:37:19 +01:00
committed by GitHub
parent 823fe23f9b
commit 2e5fb0b251
3 changed files with 14 additions and 15 deletions
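For context, a minimal sketch of the caching rule this change enforces: the incremental kv-cache only applies to self-attention, where keys and values are derived from the current hidden states. When key_value_states are supplied externally (cross-attention over the encoder output), they are already complete at every decoding step and must not be concatenated onto a cache. The KvCache struct and update helper below are illustrative stand-ins, not the actual T5Attention code; only use_cache, kv_cache, key_value_states, Tensor::cat and contiguous come from the diff itself. The sketch also mirrors the second part of the change: contiguous is called once on the final k/v instead of inside the reshape/cat chain.

    // Illustrative sketch only; assumes a candle-style Tensor API.
    use candle::{Result, Tensor};

    struct KvCache {
        use_cache: bool,
        kv_cache: Option<(Tensor, Tensor)>,
    }

    impl KvCache {
        // Grow the cache only for self-attention. Externally supplied
        // key/value states (cross-attention) are already complete and
        // must not be concatenated across decoding steps.
        fn update(
            &mut self,
            mut k: Tensor,
            mut v: Tensor,
            has_external_kv: bool,
        ) -> Result<(Tensor, Tensor)> {
            if self.use_cache && !has_external_kv {
                if let Some((cache_k, cache_v)) = &self.kv_cache {
                    // Concatenate along the sequence dimension
                    // (layout: batch, heads, seq, head_dim).
                    k = Tensor::cat(&[cache_k, &k], 2)?;
                    v = Tensor::cat(&[cache_v, &v], 2)?;
                }
                self.kv_cache = Some((k.clone(), v.clone()));
            }
            // Make k/v contiguous once, after any concatenation.
            Ok((k.contiguous()?, v.contiguous()?))
        }
    }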


@@ -348,21 +348,21 @@ impl T5Attention {
             .contiguous()?;
         let mut k = k
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
-        if self.use_cache {
+        if self.use_cache && key_value_states.is_none() {
             let _enter = self.span_cache.enter();
             if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
-                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
             };
             self.kv_cache = Some((k.clone(), v.clone()));
         };
+        let k = k.contiguous()?;
+        let v = v.contiguous()?;
         // TODO: Use flash_attn.
         let scores = {
             let _enter = self.span_mm.enter();