From 1fcac4afede215d44c4bf97c8b8c5bad06fcba09 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 26 Sep 2023 05:41:07 +0100 Subject: [PATCH] Expose a function to clear the KV cache on mixformers. (#964) --- candle-transformers/src/models/mixformer.rs | 12 ++++++++++++ .../src/models/quantized_mixformer.rs | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e945cd51..b2fa2860 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -287,6 +287,10 @@ impl MHA { .flatten_from(D::Minus2)?; attn_output.apply(&self.out_proj) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } } #[derive(Debug)] @@ -318,6 +322,10 @@ impl ParallelBlock { let feed_forward_hidden_states = self.mlp.forward(&xs)?; attn_outputs + feed_forward_hidden_states + residual } + + fn clear_kv_cache(&mut self) { + self.mixer.clear_kv_cache() + } } #[derive(Debug)] @@ -360,4 +368,8 @@ impl MixFormerSequentialForCausalLM { } xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1) } + + pub fn clear_kv_cache(&mut self) { + self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) + } } diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index 4ace2045..e458cf5c 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -268,6 +268,10 @@ impl MHA { .flatten_from(D::Minus2)?; attn_output.apply(&self.out_proj) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } } #[derive(Debug)] @@ -299,6 +303,10 @@ impl ParallelBlock { let feed_forward_hidden_states = self.mlp.forward(&xs)?; attn_outputs + feed_forward_hidden_states + residual } + + fn clear_kv_cache(&mut self) { + self.mixer.clear_kv_cache() + } } #[derive(Debug)] @@ -341,4 +349,8 @@ impl MixFormerSequentialForCausalLM { } xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1) } + + pub fn clear_kv_cache(&mut self) { + self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) + } }