From df2f89b6cf897305a566cb08446dd4522d42919a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 22 Oct 2023 09:44:48 +0100
Subject: [PATCH] Add some KV cache to blip. (#1150)

* Add some KV cache to blip.

* Mention BLIP in the readme.
---
 README.md                                   |  9 ++-
 candle-examples/examples/blip/main.rs       | 14 ++--
 candle-transformers/src/models/blip.rs      |  4 +-
 candle-transformers/src/models/blip_text.rs | 73 ++++++++++++++++-----
 4 files changed, 74 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 03c2a1f5..09f15885 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,8 @@ We also provide a some command line based examples using state of the art models
 - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
   using self-supervision (can be used for imagenet classification, depth
   evaluation, segmentation).
+- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
+  generate captions for an image.
 
 Run them using commands like:
 ```
@@ -163,8 +165,11 @@ If you have an addition to this list, please submit a pull request.
     - T5.
     - Bert.
     - Whisper (multi-lingual support).
-    - Stable Diffusion v1.5, v2.1, XL v1.0.
-    - Wurstchen v2.
+    - Text to image.
+      - Stable Diffusion v1.5, v2.1, XL v1.0.
+      - Wurstchen v2.
+    - Image to text.
+      - BLIP.
 - Computer Vision Models.
     - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
     - yolo-v3, yolo-v8.
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index 54abab8c..81c01482 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -86,17 +86,17 @@ pub fn main() -> anyhow::Result<()> {
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = blip::Config::image_captioning_large();
-    let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
-    let vision_model = model.vision_model();
-    let text_decoder = model.text_decoder();
+    let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
     println!("model built");
     // TODO: Maybe add support for the conditional prompt.
-    let image_embeds = image.unsqueeze(0)?.apply(vision_model)?;
+    let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
     let mut token_ids = vec![30522u32];
-    for _index in 0..1000 {
-        let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?;
-        let logits = text_decoder.forward(&input_ids, &image_embeds)?;
+    for index in 0..1000 {
+        let context_size = if index > 0 { 1 } else { token_ids.len() };
+        let start_pos = token_ids.len().saturating_sub(context_size);
+        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
+        let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index daa96926..7c204d21 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -296,7 +296,7 @@ impl BlipForConditionalGeneration {
         &self.vision_model
     }
 
-    pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
-        &self.text_decoder
+    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
+        &mut self.text_decoder
     }
 }
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
index f1a38f11..3a519e4c 100644
--- a/candle-transformers/src/models/blip_text.rs
+++ b/candle-transformers/src/models/blip_text.rs
@@ -44,13 +44,10 @@ impl TextEmbeddings {
             position_ids,
         })
     }
-}
 
-impl Module for TextEmbeddings {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
         let seq_len = xs.dim(1)?;
-        // Use past_key_values_length if we add a kv cache.
-        let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
+        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
         let embeddings = self.word_embedddings.forward(xs)?;
         let position_embeddings = self.position_embeddings.forward(&position_ids)?;
         (embeddings + position_embeddings)?.apply(&self.layer_norm)
@@ -65,6 +62,7 @@ struct TextSelfAttention {
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl TextSelfAttention {
@@ -88,6 +86,7 @@ impl TextSelfAttention {
             attention_head_size,
             num_attention_heads,
             attention_scale,
+            kv_cache: None,
         })
     }
 
@@ -102,8 +101,12 @@ impl TextSelfAttention {
             .permute((0, 2, 1, 3))
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
@@ -115,7 +118,15 @@ impl TextSelfAttention {
             None => {
                 let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                 let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
-                // TODO: kv cache
+                let (key, value) = match &self.kv_cache {
+                    None => (key, value),
+                    Some((prev_key, prev_value)) => {
+                        let key = Tensor::cat(&[prev_key, &key], 2)?;
+                        let value = Tensor::cat(&[prev_value, &value], 2)?;
+                        (key, value)
+                    }
+                };
+                self.kv_cache = Some((key.clone(), value.clone()));
                 (key, value)
             }
             Some(xs) => {
@@ -172,8 +183,12 @@ impl TextAttention {
         Ok(Self { self_, output })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.self_.reset_kv_cache()
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
@@ -251,14 +266,21 @@ impl TextLayer {
         })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.attention.reset_kv_cache();
+        if let Some(ca) = &mut self.cross_attention {
+            ca.reset_kv_cache()
+        }
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
         let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
-        let attention_output = match &self.cross_attention {
+        let attention_output = match &mut self.cross_attention {
             Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
             None => candle::bail!("expected some cross-attn"),
         };
@@ -283,14 +305,18 @@ impl TextEncoder {
         Ok(Self { layers })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
         let mut xs = xs.clone();
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
         }
         Ok(xs)
@@ -389,6 +415,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
+    past_kv_len: usize,
     // We do not need the pooler for caption generation
 }
 
@@ -399,22 +426,30 @@ impl TextModel {
         Ok(Self {
             embeddings,
             encoder,
+            past_kv_len: 0,
         })
     }
 
     fn forward(
-        &self,
+        &mut self,
         input_ids: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
-        let embedding_output = self.embeddings.forward(input_ids)?;
+        let (_b_sz, seq_len) = input_ids.dims2()?;
+        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
         let sequence_output = self
             .encoder
             .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
+        self.past_kv_len += seq_len;
         // We're interested in the sequence-output rather than the pooled-output.
         Ok(sequence_output)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.past_kv_len = 0;
+        self.encoder.reset_kv_cache();
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -430,7 +465,11 @@ impl TextLMHeadModel {
         Ok(Self { bert, cls })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: &Tensor,
+    ) -> Result<Tensor> {
         let seq_len = input_ids.dim(1)?;
         let mask: Vec<_> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -441,4 +480,8 @@ impl TextLMHeadModel {
         // return_logits is false so we don't discard the last sequence element.
         Ok(prediction_scores)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.bert.reset_kv_cache()
+    }
 }
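
Note on driving the new cache from caller code (not part of the patch): the rewritten loop in
candle-examples/examples/blip/main.rs above is the intended pattern. The sketch below restates it
as a standalone helper, assuming the `model`, `logits_processor`, `image_embeds`, and token ids
(30522 for BOS, 102 for [SEP]) that the example already sets up; `generate_caption` is a
hypothetical name, not an API added by this patch:

```rust
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::blip;

// Assumed from the example: [SEP] marks the end of the caption.
const SEP_TOKEN_ID: u32 = 102;

fn generate_caption(
    model: &mut blip::BlipForConditionalGeneration,
    logits_processor: &mut LogitsProcessor,
    image_embeds: &Tensor,
    device: &Device,
) -> candle::Result<Vec<u32>> {
    // Drop any state left over from a previous caption: the per-layer
    // key/value tensors and the past_kv_len position offset.
    model.text_decoder().reset_kv_cache();
    let mut token_ids = vec![30522u32]; // BOS, as in the example.
    for index in 0..1000 {
        // First step: feed the whole prompt to prime the cache. Later
        // steps: the cache already holds every earlier position, so only
        // the most recently sampled token needs to be fed.
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], device)?.unsqueeze(0)?;
        let logits = model.text_decoder().forward(&input_ids, image_embeds)?;
        // Only the logits of the last position are sampled from.
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
    }
    Ok(token_ids)
}
```

Calling reset_kv_cache up front (or between images) matters because the cache and past_kv_len
now deliberately persist across forward calls.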
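Why Tensor::cat(.., 2) in TextSelfAttention: after transpose_for_scores the key/value tensors are
laid out as (batch, num_heads, seq_len, head_dim), so dimension 2 is the sequence axis the cache
must grow along; correspondingly, past_kv_len offsets the position ids so a single fed token still
receives the embedding of its absolute position. A self-contained shape check (the 12-head /
64-dim figures are representative bert-base-style numbers, not values read from the BLIP config):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Cached keys for 3 already-processed positions:
    // (batch = 1, heads = 12, seq = 3, head_dim = 64).
    let prev_key = Tensor::zeros((1, 12, 3, 64), DType::F32, &dev)?;
    // Key for the single newly fed token: (1, 12, 1, 64).
    let key = Tensor::zeros((1, 12, 1, 64), DType::F32, &dev)?;
    // Dimension 2 is the sequence axis, so the cache grows to seq = 4.
    let cached_key = Tensor::cat(&[&prev_key, &key], 2)?;
    assert_eq!(cached_key.dims(), &[1, 12, 4, 64]);
    Ok(())
}
```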