mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add some KV cache to blip. (#1150)
* Add some KV cache to blip. * Mention BLIP in the readme.
This commit is contained in:
@ -99,6 +99,8 @@ We also provide a some command line based examples using state of the art models
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
|
||||
Run them using commands like:
|
||||
```
|
||||
@ -163,8 +165,11 @@ If you have an addition to this list, please submit a pull request.
|
||||
- T5.
|
||||
- Bert.
|
||||
- Whisper (multi-lingual support).
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Text to image.
|
||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
|
@ -86,17 +86,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let config = blip::Config::image_captioning_large();
|
||||
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let vision_model = model.vision_model();
|
||||
let text_decoder = model.text_decoder();
|
||||
let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
println!("model built");
|
||||
// TODO: Maybe add support for the conditional prompt.
|
||||
let image_embeds = image.unsqueeze(0)?.apply(vision_model)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
|
||||
let mut token_ids = vec![30522u32];
|
||||
for _index in 0..1000 {
|
||||
let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?;
|
||||
let logits = text_decoder.forward(&input_ids, &image_embeds)?;
|
||||
for index in 0..1000 {
|
||||
let context_size = if index > 0 { 1 } else { token_ids.len() };
|
||||
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||
let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
|
@ -296,7 +296,7 @@ impl BlipForConditionalGeneration {
|
||||
&self.vision_model
|
||||
}
|
||||
|
||||
pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
|
||||
&self.text_decoder
|
||||
pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
|
||||
&mut self.text_decoder
|
||||
}
|
||||
}
|
||||
|
@ -44,13 +44,10 @@ impl TextEmbeddings {
|
||||
position_ids,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
|
||||
let seq_len = xs.dim(1)?;
|
||||
// Use past_key_values_length if we add a kv cache.
|
||||
let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
|
||||
let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
|
||||
let embeddings = self.word_embedddings.forward(xs)?;
|
||||
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
|
||||
(embeddings + position_embeddings)?.apply(&self.layer_norm)
|
||||
@ -65,6 +62,7 @@ struct TextSelfAttention {
|
||||
attention_head_size: usize,
|
||||
num_attention_heads: usize,
|
||||
attention_scale: f64,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl TextSelfAttention {
|
||||
@ -88,6 +86,7 @@ impl TextSelfAttention {
|
||||
attention_head_size,
|
||||
num_attention_heads,
|
||||
attention_scale,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
@ -102,8 +101,12 @@ impl TextSelfAttention {
|
||||
.permute((0, 2, 1, 3))
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
@ -115,7 +118,15 @@ impl TextSelfAttention {
|
||||
None => {
|
||||
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
|
||||
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
|
||||
// TODO: kv cache
|
||||
let (key, value) = match &self.kv_cache {
|
||||
None => (key, value),
|
||||
Some((prev_key, prev_value)) => {
|
||||
let key = Tensor::cat(&[prev_key, &key], 2)?;
|
||||
let value = Tensor::cat(&[prev_value, &value], 2)?;
|
||||
(key, value)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key.clone(), value.clone()));
|
||||
(key, value)
|
||||
}
|
||||
Some(xs) => {
|
||||
@ -172,8 +183,12 @@ impl TextAttention {
|
||||
Ok(Self { self_, output })
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_.reset_kv_cache()
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
@ -251,14 +266,21 @@ impl TextLayer {
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.attention.reset_kv_cache();
|
||||
if let Some(ca) = &mut self.cross_attention {
|
||||
ca.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
|
||||
let attention_output = match &self.cross_attention {
|
||||
let attention_output = match &mut self.cross_attention {
|
||||
Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
|
||||
None => candle::bail!("expected some cross-attn"),
|
||||
};
|
||||
@ -283,14 +305,18 @@ impl TextEncoder {
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
|
||||
}
|
||||
Ok(xs)
|
||||
@ -389,6 +415,7 @@ impl Module for TextOnlyMLMHead {
|
||||
struct TextModel {
|
||||
embeddings: TextEmbeddings,
|
||||
encoder: TextEncoder,
|
||||
past_kv_len: usize,
|
||||
// We do not need the pooler for caption generation
|
||||
}
|
||||
|
||||
@ -399,22 +426,30 @@ impl TextModel {
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
past_kv_len: 0,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(input_ids)?;
|
||||
let (_b_sz, seq_len) = input_ids.dims2()?;
|
||||
let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
|
||||
let sequence_output =
|
||||
self.encoder
|
||||
.forward(&embedding_output, encoder_hidden_states, attention_mask)?;
|
||||
self.past_kv_len += seq_len;
|
||||
// We're interested in the sequence-output rather than the pooled-output.
|
||||
Ok(sequence_output)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.past_kv_len = 0;
|
||||
self.encoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -430,7 +465,11 @@ impl TextLMHeadModel {
|
||||
Ok(Self { bert, cls })
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let seq_len = input_ids.dim(1)?;
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
@ -441,4 +480,8 @@ impl TextLMHeadModel {
|
||||
// return_logits is false so we don't discard the last sequence element.
|
||||
Ok(prediction_scores)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.bert.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user