Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00
Add a KV cache to marian decoding. (#1226)
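The decoder's self-attention layers now keep the key/value tensors computed at earlier steps, so after the first step the example feeds only the newly sampled token and tells the model how many positions are already cached. The user-visible change is the `MTModel::decode` signature; a condensed before/after, taken from the hunks below:

    // before: the whole decoded prefix had to be re-run at every step
    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor>

    // after: xs can be just the last token, with `past_kv_len` giving the
    // number of positions already held in the per-layer KV cache
    pub fn decode(&mut self, xs: &Tensor, encoder_xs: &Tensor, past_kv_len: usize) -> Result<Tensor>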
@@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
     if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
         print!("{rest}");
     }
+    println!();
     Ok(())
 }
@@ -8,6 +8,7 @@ use anyhow::Error as E;
 use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::marian;
 
@@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
     };
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
 
     let device = candle_examples::device(args.cpu)?;
     let vb = {
@@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
     };
-    let model = marian::MTModel::new(&config, vb)?;
+    let mut model = marian::MTModel::new(&config, vb)?;
 
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);
@@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> {
 
     let mut token_ids = vec![config.decoder_start_token_id];
     for index in 0..1000 {
-        // TODO: Add a kv cache.
-        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+        let context_size = if index >= 1 { 1 } else { token_ids.len() };
         let start_pos = token_ids.len().saturating_sub(context_size);
         let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.decode(&input_ids, &encoder_xs)?;
+        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
         token_ids.push(token);
-        println!("{token}");
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
         if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
     }
-    println!(
-        "{}",
-        tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
-    );
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+    println!();
     Ok(())
 }
@@ -126,6 +126,8 @@ struct Attention {
     scaling: f64,
     num_heads: usize,
     head_dim: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
+    is_decoder: bool,
 }
 
 impl Attention {
@@ -150,6 +152,8 @@ impl Attention {
             scaling,
             num_heads,
             head_dim,
+            kv_cache: None,
+            is_decoder,
         })
     }
 
@@ -161,7 +165,7 @@ impl Attention {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
@@ -173,8 +177,21 @@ impl Attention {
             None => {
                 let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                 let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
-                (key_states, value_states)
-            }
+                if self.is_decoder {
+                    let kv_states = match &self.kv_cache {
+                        None => (key_states, value_states),
+                        Some((p_key_states, p_value_states)) => {
+                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
+                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
+                            (key_states, value_states)
+                        }
+                    };
+                    self.kv_cache = Some(kv_states.clone());
+                    kv_states
+                } else {
+                    (key_states, value_states)
+                }
+            }
             Some(kv_states) => {
                 let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                 let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
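For context on the concatenation above: after `_shape`, the key and value tensors are laid out as (batch, num_heads, seq_len, head_dim), so appending along dimension 2 grows the sequence axis of the cache by the number of freshly processed positions. A small, self-contained sketch of that append step; the shapes and the `append_to_cache` helper are illustrative, not part of the model code:

    use candle::{DType, Device, Result, Tensor};

    // Illustrative helper mirroring the cache update in Attention::forward:
    // keys/values shaped (batch, num_heads, seq_len, head_dim), appended on dim 2.
    fn append_to_cache(
        cache: &mut Option<(Tensor, Tensor)>,
        k: &Tensor,
        v: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (k, v) = match cache.as_ref() {
            None => (k.clone(), v.clone()),
            Some((pk, pv)) => (Tensor::cat(&[pk, k], 2)?, Tensor::cat(&[pv, v], 2)?),
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k, v))
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, h, d) = (1usize, 8usize, 64usize);
        let mut cache = None;
        // First call covers the initial positions, later calls add one position each.
        for seq_len in [4usize, 1, 1] {
            let k = Tensor::zeros((b, h, seq_len, d), DType::F32, &dev)?;
            let v = Tensor::zeros((b, h, seq_len, d), DType::F32, &dev)?;
            let (k_all, _v_all) = append_to_cache(&mut cache, &k, &v)?;
            println!("cached positions: {}", k_all.dim(2)?); // 4, then 5, then 6
        }
        Ok(())
    }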
@@ -198,6 +215,10 @@ impl Attention {
             .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
             .apply(&self.out_proj)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -227,7 +248,7 @@ impl EncoderLayer {
         })
     }
 
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
         let xs = (self.self_attn.forward(xs, None, None)? + residual)?
             .apply(&self.self_attn_layer_norm)?;
@@ -275,7 +296,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         attn_mask: &Tensor,
@@ -331,7 +352,7 @@ impl Encoder {
         })
     }
 
-    pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
         let xs = xs.apply(&self.embed_tokens)?;
         let xs = match self.embed_scale {
             None => xs,
@@ -342,7 +363,7 @@ impl Encoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs)?
         }
         Ok(xs)
@@ -380,7 +401,7 @@ impl Decoder {
     }
 
     pub fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         past_kv_len: usize,
@@ -396,7 +417,7 @@ impl Decoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, encoder_xs, attn_mask)?;
         }
         Ok(xs)
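One detail worth noting in the decoder forward pass above: the positional embedding is computed with the `past_kv_len` offset, which is what lets a single-token input be embedded at its absolute position in the sequence rather than at position 0. Schematically (plain Rust, not model code; the `positions` helper is illustrative):

    // Schematic only: the absolute positions covered by the current forward pass
    // when `past_kv_len` positions are already held in the cache.
    fn positions(past_kv_len: usize, seq_len: usize) -> Vec<usize> {
        (past_kv_len..past_kv_len + seq_len).collect()
    }

    // positions(0, 1) -> [0]   (first step: just the decoder start token)
    // positions(3, 1) -> [3]   (a later step: only the newly sampled token)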
@@ -443,15 +464,20 @@ impl MTModel {
         })
     }
 
-    pub fn encoder(&self) -> &Encoder {
-        &self.model.encoder
+    pub fn encoder(&mut self) -> &mut Encoder {
+        &mut self.model.encoder
     }
 
-    pub fn decoder(&self) -> &Decoder {
-        &self.model.decoder
+    pub fn decoder(&mut self) -> &mut Decoder {
+        &mut self.model.decoder
     }
 
-    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+    pub fn decode(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: &Tensor,
+        past_kv_len: usize,
+    ) -> Result<Tensor> {
         let seq_len = xs.dim(1)?;
         let mask: Vec<_> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -459,7 +485,7 @@ impl MTModel {
         let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
         self.model
             .decoder
-            .forward(xs, Some(encoder_xs), 0, &mask)?
+            .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
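Taken together, the updated API is driven roughly as in the example: run the encoder once, then call `decode` step by step with `past_kv_len` equal to the number of tokens already processed. A condensed usage sketch, assuming `model`, `config`, `device`, `logits_processor` and the tokenized source tensor `src_token_ids` are set up as in the example above (variable names are illustrative):

    // Encode the source sentence once; the encoder output is reused at every step.
    let encoder_xs = model.encoder().forward(&src_token_ids, 0)?;

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        // After the first step only the last token is fed; the KV cache covers the rest.
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }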