Add a clear cache function to the t5 model. (#919)

This commit is contained in:
Laurent Mazare
2023-09-21 09:01:06 +01:00
committed by GitHub
parent 7b26e513f1
commit c89b82b2d4

View File

@ -495,6 +495,10 @@ impl T5Attention {
let attn_output = self.o.forward(&attn_output)?; let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias)) Ok((attn_output, position_bias))
} }
/// Resets `kv_cache` to `None`, discarding whatever value it held.
fn clear_kv_cache(&mut self) {
    // `take` leaves `None` behind; the previous contents are dropped right away.
    self.kv_cache.take();
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -530,6 +534,10 @@ impl T5LayerSelfAttention {
let ys = (xs + ys)?; let ys = (xs + ys)?;
Ok((ys, position_bias)) Ok((ys, position_bias))
} }
/// Clears the cache of the wrapped `self_attention` module by delegating to
/// its own `clear_kv_cache`.
fn clear_kv_cache(&mut self) {
    let attention = &mut self.self_attention;
    attention.clear_kv_cache();
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -568,6 +576,10 @@ impl T5LayerCrossAttention {
let ys = (hidden_states + ys)?; let ys = (hidden_states + ys)?;
Ok((ys, position_bias)) Ok((ys, position_bias))
} }
/// Clears the cache of the wrapped `cross_attention` module by delegating to
/// its own `clear_kv_cache`.
fn clear_kv_cache(&mut self) {
    let attention = &mut self.cross_attention;
    attention.clear_kv_cache();
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -634,6 +646,11 @@ impl T5Block {
// TODO: clamp for f16? // TODO: clamp for f16?
Ok((xs, position_bias)) Ok((xs, position_bias))
} }
/// Clears the cache of this block's attention layers.
///
/// `self_attn` is cleared first, followed by every item yielded by
/// `cross_attn.iter_mut()` (nothing happens for the latter if it yields none).
fn clear_kv_cache(&mut self) {
    self.self_attn.clear_kv_cache();
    for cross_attn in self.cross_attn.iter_mut() {
        cross_attn.clear_kv_cache();
    }
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -680,6 +697,10 @@ impl T5Stack {
} }
self.final_layer_norm.forward(&hidden_states) self.final_layer_norm.forward(&hidden_states)
} }
/// Calls `clear_kv_cache` on each entry of `block`, in iteration order.
fn clear_kv_cache(&mut self) {
    for block in self.block.iter_mut() {
        block.clear_kv_cache();
    }
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -709,6 +730,10 @@ impl T5EncoderModel {
pub fn device(&self) -> &Device { pub fn device(&self) -> &Device {
&self.device &self.device
} }
/// Clears the model's cached state by delegating to the `encoder`'s
/// `clear_kv_cache`.
pub fn clear_kv_cache(&mut self) {
    let encoder = &mut self.encoder;
    encoder.clear_kv_cache();
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -808,4 +833,9 @@ impl T5ForConditionalGeneration {
pub fn device(&self) -> &Device { pub fn device(&self) -> &Device {
&self.device &self.device
} }
pub fn clear_kv_cache(&mut self) {
self.encoder.clear_kv_cache();
self.decoder.clear_kv_cache();
}
} }