mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Do not use the kv-cache on external key-value states. (#1054)
This commit is contained in:
@ -348,21 +348,21 @@ impl T5Attention {
|
||||
.contiguous()?;
|
||||
let mut k = k
|
||||
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
|
||||
if self.use_cache {
|
||||
if self.use_cache && key_value_states.is_none() {
|
||||
let _enter = self.span_cache.enter();
|
||||
if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
|
||||
k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
|
||||
v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
|
||||
k = Tensor::cat(&[kv_cache_k, &k], 2)?;
|
||||
v = Tensor::cat(&[kv_cache_v, &v], 2)?;
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
};
|
||||
let k = k.contiguous()?;
|
||||
let v = v.contiguous()?;
|
||||
// TODO: Use flash_attn.
|
||||
let scores = {
|
||||
let _enter = self.span_mm.enter();
|
||||
|
Reference in New Issue
Block a user