Mirror of https://github.com/huggingface/candle.git (synced 2025-06-20 04:00:28 +00:00)

Add some KV cache to blip. (#1150)
* Add some KV cache to blip.
* Mention BLIP in the readme.

candle-transformers/src/models/blip.rs:

@@ -296,7 +296,7 @@ impl BlipForConditionalGeneration {
         &self.vision_model
     }
 
-    pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
-        &self.text_decoder
+    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
+        &mut self.text_decoder
     }
 }

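The accessor changes from `&self`/`&` to `&mut self`/`&mut` because the text decoder now carries mutable KV-cache state, so any caller that drives generation needs a mutable borrow. A minimal sketch of that borrow pattern with toy types (not the actual candle structs):

```rust
// Toy illustration: a decoder that mutates an internal cache can only be
// driven through a `&mut` accessor, mirroring `text_decoder(&mut self)`.
struct ToyDecoder {
    cache_len: usize,
}

impl ToyDecoder {
    // Stands in for a forward pass that appends to the KV cache.
    fn forward(&mut self, new_tokens: usize) -> usize {
        self.cache_len += new_tokens;
        self.cache_len
    }
}

struct ToyModel {
    decoder: ToyDecoder,
}

impl ToyModel {
    // Hands out a mutable reference so stateful forward passes are possible.
    fn decoder(&mut self) -> &mut ToyDecoder {
        &mut self.decoder
    }
}

fn main() {
    let mut model = ToyModel { decoder: ToyDecoder { cache_len: 0 } };
    assert_eq!(model.decoder().forward(4), 4); // prompt of 4 tokens
    assert_eq!(model.decoder().forward(1), 5); // one generated token
}
```
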
candle-transformers/src/models/blip_text.rs:

@@ -44,13 +44,10 @@ impl TextEmbeddings {
             position_ids,
         })
     }
-}
 
-impl Module for TextEmbeddings {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
         let seq_len = xs.dim(1)?;
-        // Use past_key_values_length if we add a kv cache.
-        let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
+        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
         let embeddings = self.word_embedddings.forward(xs)?;
         let position_embeddings = self.position_embeddings.forward(&position_ids)?;
         (embeddings + position_embeddings)?.apply(&self.layer_norm)

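With a cache in place, only the freshly generated tokens are run through the embeddings, so their position ids have to start at `past_kv_len` instead of 0; `narrow(1, past_kv_len, seq_len)` selects exactly that window. A small sketch of the slicing, assuming the `candle_core` crate (re-exported as `candle` inside this repo) and made-up lengths:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Precomputed position ids of shape (1, max_position_embeddings),
    // truncated to 8 positions for the sketch.
    let position_ids = Tensor::arange(0u32, 8u32, &device)?.unsqueeze(0)?;

    // Prompt pass: past_kv_len = 0, seq_len = 5 -> positions 0..5.
    let prompt_pos = position_ids.narrow(1, 0, 5)?;
    assert_eq!(prompt_pos.to_vec2::<u32>()?, vec![vec![0, 1, 2, 3, 4]]);

    // Cached step: past_kv_len = 5, seq_len = 1 -> position 5 only.
    let step_pos = position_ids.narrow(1, 5, 1)?;
    assert_eq!(step_pos.to_vec2::<u32>()?, vec![vec![5]]);
    Ok(())
}
```
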
@@ -65,6 +62,7 @@ struct TextSelfAttention {
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl TextSelfAttention {

@@ -88,6 +86,7 @@ impl TextSelfAttention {
             attention_head_size,
             num_attention_heads,
             attention_scale,
+            kv_cache: None,
         })
     }
 

@@ -102,8 +101,12 @@ impl TextSelfAttention {
             .permute((0, 2, 1, 3))
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,

@@ -115,7 +118,15 @@ impl TextSelfAttention {
             None => {
                 let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                 let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
-                // TODO: kv cache
+                let (key, value) = match &self.kv_cache {
+                    None => (key, value),
+                    Some((prev_key, prev_value)) => {
+                        let key = Tensor::cat(&[prev_key, &key], 2)?;
+                        let value = Tensor::cat(&[prev_value, &value], 2)?;
+                        (key, value)
+                    }
+                };
+                self.kv_cache = Some((key.clone(), value.clone()));
                 (key, value)
             }
             Some(xs) => {

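This is the heart of the cache: after `transpose_for_scores` the keys and values have shape `(batch, num_heads, seq_len, head_size)`, so appending the current step to the cached history is a concatenation along dimension 2. A shape-only sketch, assuming `candle_core` and made-up dimensions:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (batch, heads, head_size) = (1, 12, 64);

    // Cached keys from a 5-token prompt: (batch, heads, 5, head_size).
    let prev_key = Tensor::zeros((batch, heads, 5, head_size), DType::F32, &device)?;
    // Key for the single newly generated token: (batch, heads, 1, head_size).
    let new_key = Tensor::zeros((batch, heads, 1, head_size), DType::F32, &device)?;

    // Same concatenation as in the diff: append along the sequence axis (dim 2).
    let key = Tensor::cat(&[&prev_key, &new_key], 2)?;
    assert_eq!(key.dims4()?, (batch, heads, 6, head_size));
    Ok(())
}
```
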
@@ -172,8 +183,12 @@ impl TextAttention {
         Ok(Self { self_, output })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.self_.reset_kv_cache()
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
         attention_mask: Option<&Tensor>,

@@ -251,14 +266,21 @@ impl TextLayer {
         })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.attention.reset_kv_cache();
+        if let Some(ca) = &mut self.cross_attention {
+            ca.reset_kv_cache()
+        }
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
         let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
-        let attention_output = match &self.cross_attention {
+        let attention_output = match &mut self.cross_attention {
             Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
             None => candle::bail!("expected some cross-attn"),
         };

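A side effect of `TextAttention::forward` taking `&mut self` is that the optional cross-attention block has to be matched through a mutable borrow (`match &mut self.cross_attention`) before `ca.forward(...)` can touch its own cache. A toy sketch of that Rust borrow pattern (not candle code):

```rust
// Calling a `&mut self` method on a value stored in an Option requires
// matching on `&mut option`, which is what the diff switches to.
struct Cache {
    len: usize,
}

impl Cache {
    fn push(&mut self) -> usize {
        self.len += 1;
        self.len
    }
}

struct Layer {
    cross_attention: Option<Cache>,
}

impl Layer {
    fn step(&mut self) -> usize {
        match &mut self.cross_attention {
            // `ca` is `&mut Cache`, so the mutating call is allowed.
            Some(ca) => ca.push(),
            None => 0,
        }
    }
}

fn main() {
    let mut layer = Layer { cross_attention: Some(Cache { len: 0 }) };
    assert_eq!(layer.step(), 1);
    assert_eq!(layer.step(), 2);
}
```
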
@@ -283,14 +305,18 @@ impl TextEncoder {
         Ok(Self { layers })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
+    }
+
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
         let mut xs = xs.clone();
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
         }
         Ok(xs)

@@ -389,6 +415,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
+    past_kv_len: usize,
     // We do not need the pooler for caption generation
 }
 

@@ -399,22 +426,30 @@ impl TextModel {
         Ok(Self {
             embeddings,
             encoder,
+            past_kv_len: 0,
         })
     }
 
     fn forward(
-        &self,
+        &mut self,
         input_ids: &Tensor,
         encoder_hidden_states: &Tensor,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
-        let embedding_output = self.embeddings.forward(input_ids)?;
+        let (_b_sz, seq_len) = input_ids.dims2()?;
+        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
         let sequence_output =
             self.encoder
                 .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
+        self.past_kv_len += seq_len;
         // We're interested in the sequence-output rather than the pooled-output.
         Ok(sequence_output)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.past_kv_len = 0;
+        self.encoder.reset_kv_cache();
+    }
 }
 
 #[derive(Debug, Clone)]

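`TextModel` keeps a running `past_kv_len` so the embeddings offset their position ids while the attention layers grow their caches: the first call sees the whole prompt, later calls see one token each, and `reset_kv_cache` starts the count over. A sketch of that bookkeeping across steps, assuming `candle_core` and toy sequence lengths:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let position_ids = Tensor::arange(0u32, 16u32, &device)?.unsqueeze(0)?;

    // Same counter the diff adds to TextModel.
    let mut past_kv_len = 0usize;

    // Step 0 feeds a 4-token prompt, later steps feed a single token each.
    for seq_len in [4usize, 1, 1, 1] {
        let pos = position_ids.narrow(1, past_kv_len, seq_len)?;
        println!("past={past_kv_len} seq={seq_len} positions={:?}", pos.to_vec2::<u32>()?);
        past_kv_len += seq_len; // mirrors `self.past_kv_len += seq_len`
    }
    assert_eq!(past_kv_len, 7);
    Ok(())
}
```
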
@@ -430,7 +465,11 @@ impl TextLMHeadModel {
         Ok(Self { bert, cls })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: &Tensor,
+    ) -> Result<Tensor> {
         let seq_len = input_ids.dim(1)?;
         let mask: Vec<_> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))

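The `flat_map` builds a standard causal mask: row `i` leaves positions `j <= i` at 0.0 and blocks future positions with `f32::NEG_INFINITY` before the mask is added to the attention scores. A small sketch that materializes the same mask for `seq_len = 3`, assuming `candle_core`:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let seq_len = 3usize;

    // Same construction as in TextLMHeadModel::forward.
    let mask: Vec<_> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    let mask = Tensor::from_vec(mask, (seq_len, seq_len), &device)?;

    // Row i can attend to columns 0..=i; future columns are -inf.
    assert_eq!(
        mask.to_vec2::<f32>()?,
        vec![
            vec![0.0, f32::NEG_INFINITY, f32::NEG_INFINITY],
            vec![0.0, 0.0, f32::NEG_INFINITY],
            vec![0.0, 0.0, 0.0],
        ]
    );
    Ok(())
}
```
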
@@ -441,4 +480,8 @@ impl TextLMHeadModel {
         // return_logits is false so we don't discard the last sequence element.
         Ok(prediction_scores)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.bert.reset_kv_cache()
+    }
 }

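At the call site (for instance a captioning example), the pattern these public methods enable is: reset the cache before each image, run the prompt through once, then feed back one sampled token at a time while the cache carries the history. A toy sketch of that loop shape; the decoder below is a stand-in, not the real `TextLMHeadModel`, and the "sampling" is a stub:

```rust
// Toy stand-in for the KV-cached decoder: forward() consumes only the new
// tokens and remembers how much history it has already seen.
struct ToyCachedDecoder {
    past_kv_len: usize,
}

impl ToyCachedDecoder {
    fn forward(&mut self, new_tokens: &[u32]) -> u32 {
        self.past_kv_len += new_tokens.len();
        // Stub "logits -> next token": just echo the current cache length.
        self.past_kv_len as u32
    }

    fn reset_kv_cache(&mut self) {
        self.past_kv_len = 0;
    }
}

fn main() {
    let mut decoder = ToyCachedDecoder { past_kv_len: 0 };
    let prompt = [7u32, 11, 13]; // stand-in prompt token ids

    for _image in 0..2 {
        // Fresh cache for every image/caption, mirroring reset_kv_cache().
        decoder.reset_kv_cache();

        // Prompt pass: the full prefix goes through once.
        let mut token = decoder.forward(&prompt);
        // Decode steps: only the previously sampled token is fed back in.
        for _ in 0..3 {
            token = decoder.forward(&[token]);
        }
        assert_eq!(decoder.past_kv_len, prompt.len() + 3);
        let _ = token;
    }
}
```
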