mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add a clear cache function to the t5 model. (#919)
This commit is contained in:
@ -495,6 +495,10 @@ impl T5Attention {
|
|||||||
let attn_output = self.o.forward(&attn_output)?;
|
let attn_output = self.o.forward(&attn_output)?;
|
||||||
Ok((attn_output, position_bias))
|
Ok((attn_output, position_bias))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -530,6 +534,10 @@ impl T5LayerSelfAttention {
|
|||||||
let ys = (xs + ys)?;
|
let ys = (xs + ys)?;
|
||||||
Ok((ys, position_bias))
|
Ok((ys, position_bias))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attention.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -568,6 +576,10 @@ impl T5LayerCrossAttention {
|
|||||||
let ys = (hidden_states + ys)?;
|
let ys = (hidden_states + ys)?;
|
||||||
Ok((ys, position_bias))
|
Ok((ys, position_bias))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.cross_attention.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -634,6 +646,11 @@ impl T5Block {
|
|||||||
// TODO: clamp for f16?
|
// TODO: clamp for f16?
|
||||||
Ok((xs, position_bias))
|
Ok((xs, position_bias))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -680,6 +697,10 @@ impl T5Stack {
|
|||||||
}
|
}
|
||||||
self.final_layer_norm.forward(&hidden_states)
|
self.final_layer_norm.forward(&hidden_states)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.block.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -709,6 +730,10 @@ impl T5EncoderModel {
|
|||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.encoder.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -808,4 +833,9 @@ impl T5ForConditionalGeneration {
|
|||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.encoder.clear_kv_cache();
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user