diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 2b71fcda..94cf5233 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -495,6 +495,10 @@ impl T5Attention {
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug)]
@@ -530,6 +534,10 @@ impl T5LayerSelfAttention {
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -568,6 +576,10 @@ impl T5LayerCrossAttention {
         let ys = (hidden_states + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.cross_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -634,6 +646,11 @@ impl T5Block {
         // TODO: clamp for f16?
         Ok((xs, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+    }
 }
 
 #[derive(Debug)]
@@ -680,6 +697,10 @@ impl T5Stack {
         }
         self.final_layer_norm.forward(&hidden_states)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+    }
 }
 
 #[derive(Debug)]
@@ -709,6 +730,10 @@ impl T5EncoderModel {
     pub fn device(&self) -> &Device {
         &self.device
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -808,4 +833,9 @@ impl T5ForConditionalGeneration {
    pub fn device(&self) -> &Device {
        &self.device
    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache();
+        self.decoder.clear_kv_cache();
+    }
 }
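
For context, a minimal usage sketch of the new public method (not part of the diff above): the `generate` helper below is hypothetical and only stands in for the usual encode plus autoregressive decode loop; the one call that comes from this change is `clear_kv_cache`, which resets the cached keys/values so a single model instance can be reused across prompts.

// Minimal sketch, assuming the candle_core / candle_transformers crates;
// `run_two_prompts` and `generate` are hypothetical helpers, only
// `clear_kv_cache` is introduced by the diff above.
use candle_core::{Result, Tensor};
use candle_transformers::models::t5::T5ForConditionalGeneration;

// Hypothetical stand-in for the usual encode + token-by-token decode loop.
fn generate(_model: &mut T5ForConditionalGeneration, _prompt_ids: &Tensor) -> Result<Vec<u32>> {
    todo!("encode the prompt, then decode autoregressively")
}

fn run_two_prompts(
    model: &mut T5ForConditionalGeneration,
    first: &Tensor,
    second: &Tensor,
) -> Result<()> {
    let _out_a = generate(model, first)?;
    // Reset the decoder self-attention and cross-attention caches so the
    // second generation does not reuse keys/values from the first prompt.
    model.clear_kv_cache();
    let _out_b = generate(model, second)?;
    Ok(())
}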