Add a KV cache to marian decoding. (#1226)

Author: Laurent Mazare (committed by GitHub)
Date: 2023-10-31 09:47:44 +01:00
Commit: c12ad45562 (parent 7d0202710b)

3 changed files with 55 additions and 24 deletions
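For context, this is the standard KV-cache pattern for autoregressive decoding: each decoder self-attention layer keeps the key/value projections it has already computed, so every subsequent step only has to embed and project the newest token while still attending over the whole history. The sketch below is a minimal, framework-free illustration of that idea; `LayerKvCache` and its methods are made-up names for this example, with plain Vec<f32> rows standing in for the projected key/value tensors.

    // Minimal, framework-free sketch of per-layer KV caching for incremental
    // decoding. One Vec<f32> row stands in for the projections of one position;
    // all names here are illustrative only.
    #[derive(Default)]
    struct LayerKvCache {
        keys: Vec<Vec<f32>>,   // one projected key row per past position
        values: Vec<Vec<f32>>, // one projected value row per past position
    }

    impl LayerKvCache {
        /// Append the projections for the newest token and return views over
        /// the full history, mirroring a concatenation along the sequence axis.
        fn append(&mut self, k: Vec<f32>, v: Vec<f32>) -> (&[Vec<f32>], &[Vec<f32>]) {
            self.keys.push(k);
            self.values.push(v);
            (&self.keys, &self.values)
        }

        /// Length of the cached prefix, i.e. the `past_kv_len` of the next step.
        fn past_kv_len(&self) -> usize {
            self.keys.len()
        }

        /// Drop the history before decoding an unrelated sequence.
        fn reset(&mut self) {
            self.keys.clear();
            self.values.clear();
        }
    }

    fn main() {
        let mut cache = LayerKvCache::default();
        // First step: the whole prompt (here 3 tokens) is projected and cached.
        for t in 0..3 {
            cache.append(vec![t as f32; 4], vec![t as f32; 4]);
        }
        // Later steps: only the newest token is projected; attention still sees
        // every cached position.
        let (k, v) = cache.append(vec![3.0; 4], vec![3.0; 4]);
        assert_eq!(k.len(), 4);
        assert_eq!(v.len(), 4);
        println!("past_kv_len for the next step: {}", cache.past_kv_len());
        cache.reset();
    }

The actual change below stores the equivalent state as an Option<(Tensor, Tensor)> inside each Attention and concatenates along the sequence dimension with Tensor::cat.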

View File

@@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
     if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
         print!("{rest}");
     }
+    println!();
     Ok(())
 }
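The hunk above only touches the streamed output of that example: tokens are printed as they are produced, `decode_rest()` flushes whatever the tokenizer was still holding back, and a trailing `println!()` ends the output on a newline. As a hedged illustration of why a streaming printer sometimes returns nothing for a token, here is a toy stand-in that buffers raw bytes and only emits complete UTF-8; it is not the actual `TokenOutputStream` from candle-examples, which works on token ids and a `tokenizers::Tokenizer`.

    use std::io::Write;

    /// Simplified stand-in for a streaming token printer: raw token bytes are
    /// buffered and only the longest valid UTF-8 prefix is emitted, so partial
    /// multi-byte characters are held back until a later token completes them.
    #[derive(Default)]
    struct StreamingPrinter {
        pending: Vec<u8>,
    }

    impl StreamingPrinter {
        /// Feed the bytes of one decoded token piece; return any text that is
        /// now safe to print (None while everything is a partial character).
        fn next_token(&mut self, piece: &[u8]) -> Option<String> {
            self.pending.extend_from_slice(piece);
            let valid_up_to = match std::str::from_utf8(&self.pending) {
                Ok(s) => s.len(),
                Err(e) => e.valid_up_to(),
            };
            if valid_up_to == 0 {
                return None;
            }
            let out = String::from_utf8(self.pending.drain(..valid_up_to).collect()).ok()?;
            Some(out)
        }

        /// Flush whatever is left at the end of generation (lossy for safety).
        fn decode_rest(&mut self) -> Option<String> {
            if self.pending.is_empty() {
                None
            } else {
                Some(String::from_utf8_lossy(&self.pending).into_owned())
            }
        }
    }

    fn main() -> std::io::Result<()> {
        // "é" is two UTF-8 bytes; split it across two token pieces.
        let pieces: [&[u8]; 3] = [b"caf", &[0xC3], &[0xA9, b' ', b'!']];
        let mut printer = StreamingPrinter::default();
        for piece in pieces {
            if let Some(t) = printer.next_token(piece) {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        if let Some(rest) = printer.decode_rest() {
            print!("{rest}");
        }
        // Trailing newline so the shell prompt does not get glued to the output.
        println!();
        Ok(())
    }

Running it prints "café !" followed by a newline; the lone 0xC3 byte is held until the next piece completes the character.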

View File

@@ -8,6 +8,7 @@ use anyhow::Error as E;
 use clap::{Parser, ValueEnum};
 use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::marian;
@@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
     };
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
     let device = candle_examples::device(args.cpu)?;
     let vb = {
@@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
     };
-    let model = marian::MTModel::new(&config, vb)?;
+    let mut model = marian::MTModel::new(&config, vb)?;
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);
@@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> {
     let mut token_ids = vec![config.decoder_start_token_id];
     for index in 0..1000 {
-        // TODO: Add a kv cache.
-        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+        let context_size = if index >= 1 { 1 } else { token_ids.len() };
         let start_pos = token_ids.len().saturating_sub(context_size);
         let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.decode(&input_ids, &encoder_xs)?;
+        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
         token_ids.push(token);
-        println!("{token}");
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
         if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
     }
-    println!(
-        "{}",
-        tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
-    );
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+    println!();
     Ok(())
 }
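The interesting part of the loop above is the interaction between `context_size` and `start_pos`: on the first iteration the full `token_ids` prefix is fed so every layer's cache gets primed, and from then on only the most recent token is passed while `start_pos` doubles as the `past_kv_len` handed to `model.decode`. Below is a minimal stand-alone sketch of that control flow, with a made-up `decode_step` closure playing the role of the model plus sampling.

    // Framework-free sketch of the incremental decoding loop above. `decode_step`
    // is a stand-in for model.decode + sampling: it receives only the tokens that
    // were not decoded yet plus the length of the cached prefix, and returns the
    // next token id. The token values and the eos id (2) are made up.
    fn main() {
        let decoder_start_token_id: u32 = 0;
        let eos_token_id: u32 = 2;

        // Pretend model: derive the next token from the current position so the
        // run terminates; a real model would attend over cache + new tokens.
        let decode_step = |new_tokens: &[u32], past_kv_len: usize| -> u32 {
            assert!(!new_tokens.is_empty());
            // After the first step the caller passes exactly one new token.
            if past_kv_len > 0 {
                assert_eq!(new_tokens.len(), 1);
            }
            let pos = past_kv_len + new_tokens.len();
            if pos >= 5 { eos_token_id } else { pos as u32 + 10 }
        };

        let mut token_ids = vec![decoder_start_token_id];
        for index in 0..1000 {
            // Same trick as in the example: full context on the first step (to
            // prime the cache), a single token afterwards.
            let context_size = if index >= 1 { 1 } else { token_ids.len() };
            let start_pos = token_ids.len().saturating_sub(context_size);
            let token = decode_step(&token_ids[start_pos..], start_pos);
            token_ids.push(token);
            if token == eos_token_id {
                break;
            }
        }
        println!("generated ids: {token_ids:?}");
    }

Note that `token_ids` still accumulates every generated id even though only the last one is sent to the model; the history lives in the per-layer caches instead.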

View File

@@ -126,6 +126,8 @@ struct Attention {
     scaling: f64,
     num_heads: usize,
     head_dim: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
+    is_decoder: bool,
 }

 impl Attention {
@@ -150,6 +152,8 @@ impl Attention {
             scaling,
             num_heads,
             head_dim,
+            kv_cache: None,
+            is_decoder,
         })
     }
@@ -161,7 +165,7 @@
     }

     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
@@ -173,8 +177,21 @@ impl Attention {
             None => {
                 let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                 let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
-                (key_states, value_states)
-            }
+                if self.is_decoder {
+                    let kv_states = match &self.kv_cache {
+                        None => (key_states, value_states),
+                        Some((p_key_states, p_value_states)) => {
+                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
+                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
+                            (key_states, value_states)
+                        }
+                    };
+                    self.kv_cache = Some(kv_states.clone());
+                    kv_states
+                } else {
+                    (key_states, value_states)
+                }
+            }
             Some(kv_states) => {
                 let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                 let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
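In the hunk above, only the decoder self-attention (the `None` kv_states branch with `is_decoder` set) caches anything: freshly projected keys and values are concatenated onto the cached ones along dim 2, the sequence axis of the projected tensors, written back into `kv_cache`, and used for the current step, while the cross-attention branch keeps recomputing projections of the encoder output. The toy below reduces this to a single head with 2-D shapes to show why one query row is enough once the keys are cached; all names are illustrative.

    // Sketch of why a single query row suffices once keys are cached. Shapes
    // are reduced to (seq, head_dim) for one head of one batch element; in the
    // real code the equivalent concatenation is Tensor::cat(..., 2) along the
    // sequence axis.
    fn matvec_scores(q: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        // One attention score per cached or current position: q . k_i
        keys.iter()
            .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum())
            .collect()
    }

    fn main() {
        let head_dim = 4;
        // Keys already cached from the 3 previously decoded positions.
        let mut cached_keys: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32; head_dim]).collect();

        // New step: project only the latest token, then append along the
        // sequence axis (the Vec push plays the role of the tensor concat).
        let new_key = vec![3.0; head_dim];
        cached_keys.push(new_key);

        // A single query row still attends over every position seen so far.
        let q = vec![1.0; head_dim];
        let scores = matvec_scores(&q, &cached_keys);
        assert_eq!(scores.len(), 4); // past 3 positions + the current one
        println!("attention scores over past+current positions: {scores:?}");
    }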
@@ -198,6 +215,10 @@ impl Attention {
             .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
             .apply(&self.out_proj)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }

 #[derive(Debug, Clone)]
@@ -227,7 +248,7 @@ impl EncoderLayer {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
         let xs = (self.self_attn.forward(xs, None, None)? + residual)?
             .apply(&self.self_attn_layer_norm)?;
@@ -275,7 +296,7 @@ impl DecoderLayer {
     }

     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         attn_mask: &Tensor,
@@ -331,7 +352,7 @@ impl Encoder {
         })
     }

-    pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
         let xs = xs.apply(&self.embed_tokens)?;
         let xs = match self.embed_scale {
             None => xs,
@@ -342,7 +363,7 @@ impl Encoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs)?
         }
         Ok(xs)
@@ -380,7 +401,7 @@ impl Decoder {
     }

     pub fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         past_kv_len: usize,
@@ -396,7 +417,7 @@ impl Decoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, encoder_xs, attn_mask)?;
         }
         Ok(xs)
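A side effect of storing the cache inline is the `&self` to `&mut self` ripple visible in the hunks above: once `Attention::forward` mutates its own state, `EncoderLayer`, `DecoderLayer`, `Encoder` and `Decoder` all need `&mut self` and `iter_mut()` to call into it. A minimal sketch of why the signature change propagates; interior mutability (e.g. a RefCell around the cache) would be the alternative, which the diff does not use.

    // Why `&self` -> `&mut self` ripples upwards: once a layer stores its cache
    // inline, every caller that invokes forward() needs unique access.
    struct Layer {
        cache: Vec<f32>, // stands in for Option<(Tensor, Tensor)>
    }

    impl Layer {
        fn forward(&mut self, x: f32) -> f32 {
            self.cache.push(x); // the mutation is what forces &mut self
            self.cache.iter().sum()
        }
    }

    struct Model {
        layers: Vec<Layer>,
    }

    impl Model {
        // Must also take &mut self and use iter_mut(), mirroring the diff.
        fn forward(&mut self, mut x: f32) -> f32 {
            for layer in self.layers.iter_mut() {
                x = layer.forward(x);
            }
            x
        }
    }

    fn main() {
        let mut model = Model {
            layers: vec![Layer { cache: vec![] }, Layer { cache: vec![] }],
        };
        println!("step 1: {}", model.forward(1.0));
        println!("step 2: {}", model.forward(2.0));
    }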
@@ -443,15 +464,20 @@ impl MTModel {
         })
     }

-    pub fn encoder(&self) -> &Encoder {
-        &self.model.encoder
+    pub fn encoder(&mut self) -> &mut Encoder {
+        &mut self.model.encoder
     }

-    pub fn decoder(&self) -> &Decoder {
-        &self.model.decoder
+    pub fn decoder(&mut self) -> &mut Decoder {
+        &mut self.model.decoder
     }

-    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+    pub fn decode(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: &Tensor,
+        past_kv_len: usize,
+    ) -> Result<Tensor> {
         let seq_len = xs.dim(1)?;
         let mask: Vec<_> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -459,7 +485,7 @@ impl MTModel {
         let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
         self.model
             .decoder
-            .forward(xs, Some(encoder_xs), 0, &mask)?
+            .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
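Finally, `decode` now threads the real `past_kv_len` down to the decoder (it used to pass 0), so the position embeddings keep counting from the end of the cached prefix, while the causal mask is still built only over the newly fed tokens. That is sufficient because, with the cache active, `seq_len` is 1 after the first step and a single new token may attend to everything that came before it. A small stand-alone illustration of that reasoning (generic code, not taken from the diff):

    // Causal mask in the cached setting: for `seq_len` new tokens appended after
    // `past_kv_len` cached ones, position i may attend to every cached position
    // and to new positions j <= i. With the cache in place seq_len is 1 after the
    // first step, so the mask row is all zeros, which is why the seq_len x seq_len
    // mask built in decode() above is enough in practice.
    fn causal_mask(seq_len: usize, past_kv_len: usize) -> Vec<Vec<f32>> {
        (0..seq_len)
            .map(|i| {
                (0..past_kv_len + seq_len)
                    .map(|j| {
                        if j > past_kv_len + i {
                            f32::NEG_INFINITY
                        } else {
                            0f32
                        }
                    })
                    .collect()
            })
            .collect()
    }

    fn main() {
        // First step: 4 prompt tokens, nothing cached yet -> lower-triangular mask.
        let m0 = causal_mask(4, 0);
        assert_eq!(m0[0][1], f32::NEG_INFINITY);
        assert_eq!(m0[3][3], 0f32);

        // Later step: one new token, 4 cached positions -> a single all-zero row.
        let m1 = causal_mask(1, 4);
        assert!(m1[0].iter().all(|&v| v == 0f32));
        println!("mask shapes: {}x{} then {}x{}", m0.len(), m0[0].len(), m1.len(), m1[0].len());
    }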