Fix T5 kv cache (#899)

* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
Author: Juarez Bochi
Date: 2023-09-19 12:36:15 -07:00
Committed by: GitHub
Parent: d7e48234d4
Commit: 8696f64bae
2 changed files with 26 additions and 7 deletions


@@ -348,9 +348,14 @@ impl T5Attention {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
                     // This only handles the bidirectional case.
+                    let kv_len = k.dim(2)?;
+                    let (q_start, q_end) = match self.use_cache {
+                        true => ((kv_len - q_len) as u32, kv_len as u32),
+                        false => (0_u32, kv_len as u32),
+                    };
                     let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                     let max_exact = num_buckets / 2;
-                    let relative_position = (0..q_len as u32)
+                    let relative_position = (q_start..q_end)
                         .map(|i| {
                             (0..kv_len as u32)
                                 .map(|j| {
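
The gist of the range fix, as a minimal sketch in plain Rust (no candle types; the helper name and the assert-based usage below are illustrative, not part of the diff): when a KV cache is in use, the q_len query rows being processed correspond to the last q_len positions of the kv_len key positions, so the relative-position grid feeding the attention bias must be built over q_start..q_end instead of 0..q_len.

// Sketch only: with a cache, query rows sit at the end of the key sequence,
// so relative distances are measured from q_start = kv_len - q_len.
fn relative_positions(q_len: usize, kv_len: usize, use_cache: bool) -> Vec<Vec<i64>> {
    let (q_start, q_end) = if use_cache {
        (kv_len - q_len, kv_len)
    } else {
        (0, kv_len)
    };
    (q_start..q_end)
        .map(|i| (0..kv_len).map(|j| j as i64 - i as i64).collect())
        .collect()
}

fn main() {
    // No cache: a 3-token prompt gets the full 3x3 grid of key-minus-query offsets.
    assert_eq!(
        relative_positions(3, 3, false),
        vec![vec![0, 1, 2], vec![-1, 0, 1], vec![-2, -1, 0]]
    );
    // With cache: decoding 1 new token against 4 cached keys yields one row with
    // offsets measured from position 3. The old 0..q_len range would have produced
    // [0, 1, 2, 3] and selected the wrong bias buckets.
    assert_eq!(relative_positions(1, 4, true), vec![vec![-3, -2, -1, 0]]);
}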