From 58c1e909d3ed07ce0bbc316bc80bcdd7ff109524 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sun, 22 Sep 2024 13:43:46 +0200
Subject: [PATCH] Handle contiguity + bugfix + use in mimi.

---
 candle-nn/src/kv_cache.rs                          | 16 ++++++++++------
 candle-nn/tests/kv_cache.rs                        |  2 +-
 .../src/models/mimi/transformer.rs                 |  8 ++++----
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index 59ff5648..9e860d61 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -224,23 +224,27 @@ impl RotatingCache {
 
         self.current_seq_len += seq_len;
         if seq_len >= self.max_seq_len {
-            let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?;
-            ad.slice_set(&src, self.dim, 0)?;
+            let to_copy = src
+                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
+                .contiguous()?;
+            ad.slice_set(&to_copy, self.dim, 0)?;
             self.offset = 0;
             // Here we return `src` rather than `ad` so that all the past can be used.
-            Ok(src)
+            Ok(src.clone())
         } else {
             let rem_len = self.max_seq_len - self.offset;
             if seq_len <= rem_len {
-                ad.slice_set(src, self.dim, self.offset)?;
+                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
                 self.offset = (self.offset + seq_len) % self.max_seq_len;
             } else {
                 // We have to make two copies here as we go over the boundary of the cache.
                 if rem_len > 0 {
-                    let src1 = src.narrow(self.dim, 0, rem_len)?;
+                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
                     ad.slice_set(&src1, self.dim, self.offset)?;
                 }
-                let src2 = src.narrow(self.dim, rem_len, seq_len - rem_len)?;
+                let src2 = src
+                    .narrow(self.dim, rem_len, seq_len - rem_len)?
+                    .contiguous()?;
                 ad.slice_set(&src2, self.dim, 0)?;
                 self.offset = seq_len - rem_len;
             }
diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs
index 2f70f3d4..88558d51 100644
--- a/candle-nn/tests/kv_cache.rs
+++ b/candle-nn/tests/kv_cache.rs
@@ -71,7 +71,7 @@ fn rotating_kv_cache() -> Result<()> {
 
     let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
     let data = cache.append(&t)?;
-    assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8.]);
+    assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
     assert_eq!(cache.current_seq_len(), 22);
     assert_eq!(cache.offset(), 0);
 
diff --git a/candle-transformers/src/models/mimi/transformer.rs b/candle-transformers/src/models/mimi/transformer.rs
index de221274..0fa70792 100644
--- a/candle-transformers/src/models/mimi/transformer.rs
+++ b/candle-transformers/src/models/mimi/transformer.rs
@@ -127,7 +127,7 @@ pub struct StreamingMultiheadAttention {
     context: usize,
     neg_inf: Tensor,
     rope: Option<Arc<RotaryEmbedding>>,
-    kv_cache: candle_nn::kv_cache::KvCache,
+    kv_cache: candle_nn::kv_cache::RotatingKvCache,
     pos: usize,
     use_flash_attn: bool,
     span: tracing::Span,
@@ -153,7 +153,7 @@ impl StreamingMultiheadAttention {
             num_heads: cfg.num_heads,
             context: cfg.context,
             neg_inf,
-            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
             pos: 0,
             use_flash_attn: false,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
@@ -236,7 +236,7 @@ impl StreamingMultiheadAttention {
         self.kv_cache.reset()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.kv_cache = kv_cache
     }
 }
@@ -582,7 +582,7 @@ impl StreamingTransformerLayer {
         self.self_attn.reset_kv_cache()
    }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.self_attn.set_kv_cache(kv_cache)
     }
 }
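Note on the change (not part of the patch): the narrowed chunks are now made `.contiguous()?` before being copied into the cache with `slice_set`, since a narrow along an inner dimension can yield a non-contiguous view, and when an appended chunk is at least `max_seq_len` long, `append` now returns the whole input rather than only the last `max_seq_len` positions, as the updated test asserts. Below is a minimal sketch of that behavior; it assumes the published `candle-core`/`candle-nn` crate names and the `RotatingCache` API exercised by the test above (`new(dim, max_seq_len)`, `append`, `offset`, `current_seq_len`).

```rust
// Minimal sketch (not part of the patch) of the RotatingCache behavior the
// updated test asserts. Crate names are the published candle crates; the API
// used here is the one exercised by candle-nn/tests/kv_cache.rs.
use candle_core::{Device, Result, Tensor};
use candle_nn::kv_cache::RotatingCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Rotating cache with capacity 6 along dimension 0, as in the test.
    let mut cache = RotatingCache::new(0, 6);

    // A short append fits in the cache: it is stored at the current offset and
    // the valid prefix is returned.
    let t = Tensor::new(&[0f64, 1., 2.], &dev)?;
    let data = cache.append(&t)?;
    assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2.]);
    assert_eq!(cache.offset(), 3);

    // An append longer than the capacity: only the last 6 values stay in the
    // cache (the offset wraps back to 0), but with this patch `append` returns
    // the full input so the current step can still attend to all of it.
    let t = Tensor::new(&[3f64, 4., 5., 6., 7., 8., 9., 10., 11.], &dev)?;
    let data = cache.append(&t)?;
    assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8., 9., 10., 11.]);
    assert_eq!(cache.offset(), 0);
    assert_eq!(cache.current_seq_len(), 12);
    Ok(())
}
```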