From 2e5fb0b2518aa7f7c666967fe4160462578cf8d0 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 7 Oct 2023 22:37:19 +0100
Subject: [PATCH] Do not use the kv-cache on external key-value states. (#1054)

---
 candle-examples/examples/quantized-t5/main.rs  |  1 -
 candle-transformers/src/models/quantized_t5.rs | 14 +++++++-------
 candle-transformers/src/models/t5.rs           | 14 +++++++-------
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 2bc050ee..5a1cdf0c 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -153,7 +153,6 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let _guard = if args.tracing {
-        println!("tracing...");
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
         Some(guard)
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 5f08c67d..bf5797e9 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -348,21 +348,21 @@ impl T5Attention {
             .contiguous()?;
         let mut k = k
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
 
-        if self.use_cache {
+        if self.use_cache && key_value_states.is_none() {
             let _enter = self.span_cache.enter();
             if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
-                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
             };
             self.kv_cache = Some((k.clone(), v.clone()));
         };
+        let k = k.contiguous()?;
+        let v = v.contiguous()?;
         // TODO: Use flash_attn.
         let scores = {
             let _enter = self.span_mm.enter();
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 84704ca9..bdfabf28 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -348,21 +348,21 @@ impl T5Attention {
             .contiguous()?;
         let mut k = k
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?
-            .contiguous()?;
+            .transpose(1, 2)?;
 
-        if self.use_cache {
+        if self.use_cache && key_value_states.is_none() {
             let _enter = self.span_cache.enter();
             if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
-                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
             };
             self.kv_cache = Some((k.clone(), v.clone()));
         };
+        let k = k.contiguous()?;
+        let v = v.contiguous()?;
         // TODO: Use flash_attn.
         let scores = {
             let _enter = self.span_mm.enter();
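
Below the patch is a minimal standalone sketch of the rule the change enforces, for readers who want to see it in isolation: the kv-cache should only be consulted and grown for self-attention, because cross-attention recomputes its keys and values from the full external (encoder) states on every step, so concatenating them into the cache would grow it incorrectly. The Attention struct, the keys_values helper, the has_external_kv_states flag, and the use of Vec<f32> as a stand-in for tensors are hypothetical simplifications, not the candle API.

// Sketch only: hypothetical types, not the candle implementation.
struct Attention {
    use_cache: bool,
    kv_cache: Option<(Vec<f32>, Vec<f32>)>,
}

impl Attention {
    fn keys_values(
        &mut self,
        mut k: Vec<f32>,
        mut v: Vec<f32>,
        has_external_kv_states: bool,
    ) -> (Vec<f32>, Vec<f32>) {
        // Mirrors `if self.use_cache && key_value_states.is_none()` from the
        // patch: skip the cache entirely when the key/value states come from
        // outside the decoder (cross-attention).
        if self.use_cache && !has_external_kv_states {
            if let Some((cached_k, cached_v)) = &self.kv_cache {
                // Append the new step's keys/values to the cached ones.
                let mut merged_k = cached_k.clone();
                merged_k.extend_from_slice(&k);
                k = merged_k;
                let mut merged_v = cached_v.clone();
                merged_v.extend_from_slice(&v);
                v = merged_v;
            }
            self.kv_cache = Some((k.clone(), v.clone()));
        }
        (k, v)
    }
}

fn main() {
    let mut attn = Attention { use_cache: true, kv_cache: None };
    // Self-attention: the cache grows by one step per call.
    let (k, _v) = attn.keys_values(vec![1.0], vec![1.0], false);
    assert_eq!(k.len(), 1);
    let (k, _v) = attn.keys_values(vec![2.0], vec![2.0], false);
    assert_eq!(k.len(), 2);
    // Cross-attention: the cache is neither consulted nor updated.
    let (k, _v) = attn.keys_values(vec![9.0; 4], vec![9.0; 4], true);
    assert_eq!(k.len(), 4);
    assert_eq!(attn.kv_cache.as_ref().unwrap().0.len(), 2);
}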