mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add a KV cache to T5. (#873)
* Add a KV cache to T5.
* Suggest using release mode.
* Use the kv cache in decoding.
* Add a comment.
@@ -54,7 +54,7 @@ pub struct Config {
     is_decoder: bool,
     is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
-    use_cache: bool,
+    pub use_cache: bool,
     pub pad_token_id: usize,
     pub eos_token_id: usize,
 }
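Making `use_cache` public lets callers toggle caching per stack. A minimal sketch of the intended use, mirroring what `T5ForConditionalGeneration::load` does further down in this diff:

    let mut encoder_cfg = cfg.clone();
    // The encoder runs exactly once per prompt, so there is nothing to cache.
    encoder_cfg.use_cache = false;
    // The decoder config keeps the deserialized value, which defaults to
    // `true` via the `default_use_cache` serde default.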
@@ -245,10 +245,17 @@ struct T5Attention {
     relative_attention_num_buckets: usize,
     relative_attention_max_distance: usize,
     inner_dim: usize,
+    use_cache: bool,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl T5Attention {
-    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(
+        has_relative_attention_bias: bool,
+        decoder: bool,
+        vb: VarBuilder,
+        cfg: &Config,
+    ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
         let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
         let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
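The cache itself is just an optional pair of tensors per attention layer. Its lifecycle, sketched step by step (shapes are (batch, heads, seq, head_dim), with the sequence axis at dim 2):

    step 0: kv_cache == None                -> compute (k, v) for 1 token, store them
    step 1: kv_cache == Some((k, v)), len 1 -> append the new token's k/v, len 2
    step n: kv_cache == Some((k, v)), len n -> append, len n + 1

The `decoder` flag is threaded through the `load` constructors so that, per the `use_cache: cfg.use_cache && decoder` line below, caching only ever activates inside the decoder stack; the encoder sees its whole input in one pass and has nothing to gain from a cache.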
@@ -275,11 +282,13 @@ impl T5Attention {
             relative_attention_num_buckets: cfg.relative_attention_num_buckets,
             relative_attention_max_distance: cfg.relative_attention_max_distance,
             inner_dim,
+            use_cache: cfg.use_cache && decoder,
+            kv_cache: None,
         })
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
         key_value_states: Option<&Tensor>,
@@ -287,7 +296,6 @@ impl T5Attention {
     ) -> Result<(Tensor, Option<Tensor>)> {
         // Performs Self-attention (if key_value_states is None) or attention
         // over source sentence (provided by key_value_states).
-        // TODO: kv caching.
         let kv_input = match key_value_states {
             None => xs,
             Some(key_value_states) => key_value_states,
@@ -301,14 +309,22 @@ impl T5Attention {
             .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
-        let k = k
+        let mut k = k
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
-        let v = v
+        let mut v = v
             .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
+
+        if self.use_cache {
+            if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+                k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+            };
+            self.kv_cache = Some((k.clone(), v.clone()));
+        };
         // TODO: Use flash_attn.
         let scores = q.matmul(&k.t()?)?;
         let scores = match mask {
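The core of the change is the `Tensor::cat` along dim 2, the sequence axis of the (batch, heads, seq, head_dim) layout produced by the reshape/transpose just above. A self-contained sketch of that append, assuming the `candle_core` crate and illustrative shapes:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Keys accumulated over 5 previous decoding steps: (batch, heads, past, head_dim).
        let cached_k = Tensor::zeros((1, 8, 5, 64), DType::F32, &dev)?;
        // Keys for the single new token.
        let new_k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        // Concatenate along dim 2, exactly as the patch does for both k and v.
        let k = Tensor::cat(&[&cached_k, &new_k], 2)?.contiguous()?;
        assert_eq!(k.dims(), &[1, 8, 6, 64]);
        Ok(())
    }

With a warm cache, q_len stays at 1 while kv_len keeps growing, so the attention scores come out as (batch, heads, 1, kv_len); this asymmetry is why the square causal mask has to be skipped further down.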
@@ -394,8 +410,8 @@ struct T5LayerSelfAttention {
 }
 
 impl T5LayerSelfAttention {
-    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
+    fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
         Ok(Self {
@@ -405,7 +421,7 @@ impl T5LayerSelfAttention {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
         mask: Option<&Tensor>,
@@ -426,8 +442,8 @@ struct T5LayerCrossAttention {
 }
 
 impl T5LayerCrossAttention {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+    fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
         Ok(Self {
@@ -437,7 +453,7 @@ impl T5LayerCrossAttention {
     }
 
     fn forward(
-        &self,
+        &mut self,
         hidden_states: &Tensor,
         position_bias: Option<&Tensor>,
         key_value_states: &Tensor,
@@ -462,11 +478,17 @@ struct T5Block {
 }
 
 impl T5Block {
-    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(
+        has_relative_attention_bias: bool,
+        decoder: bool,
+        vb: VarBuilder,
+        cfg: &Config,
+    ) -> Result<Self> {
         let vb = vb.pp("layer");
-        let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
+        let self_attn =
+            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
         let cross_attn = if cfg.is_decoder {
-            Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
+            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
         } else {
             None
         };
@@ -480,19 +502,28 @@ impl T5Block {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
        position_bias: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
         // TODO: Cache masks
         let mask = match self.cross_attn.is_some() {
-            true => Some(get_mask(xs.dim(1)?, xs.device())?),
+            true => {
+                let mask_len = xs.dim(1)?;
+                // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape
+                // issues when using the KV cache in the decoder.
+                if mask_len <= 1 {
+                    None
+                } else {
+                    Some(get_mask(mask_len, xs.device())?)
+                }
+            }
             false => None,
         };
         let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
         // TODO: clamp for f16?
-        if let Some(cross_attn) = &self.cross_attn {
+        if let Some(cross_attn) = &mut self.cross_attn {
             (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
             // TODO: clamp for f16?
         }
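For context, `get_mask` builds a square causal mask. A sketch of what such a helper looks like (this body is an assumption modeled on similar candle code, not part of this diff):

    use candle_core::{Device, Result, Tensor};

    // 1 marks positions a query may not attend to (strictly future positions).
    fn causal_mask(size: usize, device: &Device) -> Result<Tensor> {
        let mask: Vec<u8> = (0..size)
            .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
            .collect();
        Tensor::from_slice(&mask, (size, size), device)
    }

With the KV cache active, the decoder is fed one token per step, so the score rows are 1 x kv_len while the mask would be 1 x 1; skipping the mask for `mask_len <= 1` avoids that shape mismatch, and a single query attending to the whole past is causal by construction anyway.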
@@ -510,9 +541,9 @@ struct T5Stack {
 }
 
 impl T5Stack {
-    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
-            .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
+            .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
         let final_layer_norm = T5LayerNorm::load(
             cfg.d_model,
@@ -527,14 +558,14 @@ impl T5Stack {
     }
 
     fn forward(
-        &self,
+        &mut self,
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
-        for block in self.block.iter() {
+        for block in self.block.iter_mut() {
             (hidden_states, position_bias) = block.forward(
                 &hidden_states,
                 position_bias.as_ref(),
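The switch from `iter()` to `iter_mut()` and from `&self` to `&mut self` throughout is purely about ownership: writing to `kv_cache` requires a mutable borrow, and that requirement propagates to every caller. In miniature (a standalone illustration, not candle code):

    struct Layer {
        cache: Option<Vec<f32>>,
    }

    struct Stack {
        layers: Vec<Layer>,
    }

    impl Layer {
        // Mutating the cache forces `&mut self`...
        fn forward(&mut self, x: f32) -> f32 {
            self.cache.get_or_insert_with(Vec::new).push(x);
            x
        }
    }

    impl Stack {
        // ...so the stack needs `&mut self` and `iter_mut` as well.
        fn forward(&mut self, mut x: f32) -> f32 {
            for layer in self.layers.iter_mut() {
                x = layer.forward(x);
            }
            x
        }
    }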
@@ -555,14 +586,14 @@ impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
-        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
             encoder,
             device: vb.device().clone(),
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
         self.encoder.forward(input_ids, None)
     }
 
@@ -589,13 +620,13 @@ impl T5ForConditionalGeneration {
         encoder_cfg.is_decoder = false;
         encoder_cfg.use_cache = false;
         encoder_cfg.is_encoder_decoder = false;
-        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
 
         let mut decoder_cfg = cfg.clone();
         decoder_cfg.is_decoder = true;
         decoder_cfg.is_encoder_decoder = false;
         decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
-        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
 
         Ok(Self {
             encoder,
@@ -605,7 +636,7 @@ impl T5ForConditionalGeneration {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
         let encoder_output = self.encoder.forward(input_ids, None)?;
         let decoder_output = self
             .decoder
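Putting it together, a hypothetical greedy decoding loop that benefits from the cache: after the first step only the newest token needs to pass through the decoder, since keys and values for earlier positions are replayed from `kv_cache`. Here `vb`, `cfg`, `device`, `max_steps`, and `input_ids` are assumed to be set up elsewhere, and `pick_next` is a hypothetical token-selection helper (argmax or sampling over the logits):

    let mut model = T5ForConditionalGeneration::load(vb, &cfg)?;
    let mut output_tokens = vec![cfg.pad_token_id as u32];
    for _step in 0..max_steps {
        // With a warm cache, feed only the most recent token.
        let last = &output_tokens[output_tokens.len() - 1..];
        let decoder_ids = Tensor::new(last, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input_ids, &decoder_ids)?;
        let next_token = pick_next(&logits)?; // hypothetical helper
        if next_token as usize == cfg.eos_token_id {
            break;
        }
        output_tokens.push(next_token);
    }

As the commit message suggests, run this in release mode (`cargo run --release`); debug builds are prohibitively slow for tensor workloads.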