mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a KV cache to falcon. (#104)
This commit is contained in:
@ -51,7 +51,13 @@ impl TextGeneration {
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let context_size = if self.model.config().use_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
|
@ -2,6 +2,8 @@ use anyhow::Result;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor, D};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
pub struct VarBuilder<'a> {
|
||||
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
|
||||
dtype: DType,
|
||||
@ -180,23 +182,23 @@ impl Embedding {
|
||||
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
layer_norm_epsilon: f64,
|
||||
initializer_range: f64,
|
||||
use_cache: bool,
|
||||
bos_token_id: u32,
|
||||
eos_token_id: u32,
|
||||
hidden_dropout: f64,
|
||||
attention_dropout: f64,
|
||||
n_head_kv: Option<usize>,
|
||||
alibi: bool,
|
||||
new_decoder_architecture: bool,
|
||||
multi_query: bool,
|
||||
parallel_attn: bool,
|
||||
bias: bool,
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub layer_norm_epsilon: f64,
|
||||
pub initializer_range: f64,
|
||||
pub use_cache: bool,
|
||||
pub bos_token_id: u32,
|
||||
pub eos_token_id: u32,
|
||||
pub hidden_dropout: f64,
|
||||
pub attention_dropout: f64,
|
||||
pub n_head_kv: Option<usize>,
|
||||
pub alibi: bool,
|
||||
pub new_decoder_architecture: bool,
|
||||
pub multi_query: bool,
|
||||
pub parallel_attn: bool,
|
||||
pub bias: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@ -292,9 +294,10 @@ impl FalconRotaryEmbedding {
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
|
||||
let cache = None;
|
||||
Ok(Self { inv_freq, cache })
|
||||
Ok(Self {
|
||||
inv_freq: Tensor::new(inv_freq.as_slice(), &vb.device)?,
|
||||
cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn cos_sin(
|
||||
@ -320,9 +323,16 @@ impl FalconRotaryEmbedding {
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
fn forward(&mut self, query: &Tensor, key: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
past_kv_len: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
|
||||
let (cos, sin) = self.cos_sin(seq_len, &query.device(), query.dtype())?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, &query.device(), query.dtype())?;
|
||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||
let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
|
||||
let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
|
||||
Ok((qs, ks))
|
||||
@ -341,8 +351,10 @@ struct FalconAttention {
|
||||
query_key_value: Linear,
|
||||
dense: Linear,
|
||||
maybe_rotary: Option<FalconRotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
inv_norm_factor: f64,
|
||||
multi_query: bool,
|
||||
use_cache: bool,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
n_head_kv: usize,
|
||||
@ -381,8 +393,10 @@ impl FalconAttention {
|
||||
query_key_value,
|
||||
dense,
|
||||
maybe_rotary,
|
||||
kv_cache: None,
|
||||
inv_norm_factor: 1. / (head_dim as f64).sqrt(),
|
||||
multi_query: cfg.multi_query,
|
||||
use_cache: cfg.use_cache,
|
||||
num_heads: cfg.num_attention_heads,
|
||||
n_head_kv: cfg.n_head_kv.unwrap_or(1),
|
||||
head_dim,
|
||||
@ -408,50 +422,60 @@ impl FalconAttention {
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
|
||||
let fused_qkv = self.query_key_value.forward(x)?;
|
||||
let head_dim = self.head_dim;
|
||||
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
||||
let (b_sz, q_len, _, _) = query.shape().r4()?;
|
||||
let (b_sz, seq_len, _, _) = query.shape().r4()?;
|
||||
let query = query
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.num_heads, q_len, head_dim))?;
|
||||
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
let key = key
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
.reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
|
||||
let value = value
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
.reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
|
||||
let (query, key) = if let Some(r) = &mut self.maybe_rotary {
|
||||
r.forward(&query, &key)?
|
||||
r.forward(&query, &key, past_kv_len)?
|
||||
} else {
|
||||
(query, key)
|
||||
};
|
||||
let (mut key, mut value) = (key, value);
|
||||
let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
|
||||
// TODO: layer_past, use_cache?
|
||||
let query = query.reshape((b_sz * self.num_heads, q_len, head_dim))?;
|
||||
let key = key.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
let value = value.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
if self.use_cache {
|
||||
if let Some((cache_k, cache_v)) = &self.kv_cache {
|
||||
// TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
|
||||
// arbitrarily large sizes.
|
||||
key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
|
||||
value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
|
||||
}
|
||||
self.kv_cache = Some((key.clone(), value.clone()))
|
||||
}
|
||||
let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
let all_len = past_kv_len + seq_len;
|
||||
let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
|
||||
let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
|
||||
|
||||
let (key, value) = if self.n_head_kv == 1 {
|
||||
(
|
||||
key.broadcast_as(query.dims())?,
|
||||
value.broadcast_as(query.dims())?,
|
||||
key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
|
||||
value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
|
||||
)
|
||||
} else {
|
||||
(key, value)
|
||||
};
|
||||
|
||||
// Only handle alibi is None here, and non-flash attention.
|
||||
// Only handle the case where alibi is None here, and non-flash attention.
|
||||
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
|
||||
let attention_scores = attention_scores
|
||||
.broadcast_add(&mask.squeeze(1)?)?
|
||||
.softmax(D::Minus1)?;
|
||||
let attn_output = attention_scores
|
||||
.matmul(&value)?
|
||||
.reshape((b_sz, self.num_heads, q_len, head_dim))?
|
||||
.reshape((b_sz, self.num_heads, seq_len, head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.num_heads * head_dim))?;
|
||||
.reshape((b_sz, seq_len, self.num_heads * head_dim))?;
|
||||
let attn_output = self.dense.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
@ -524,10 +548,10 @@ impl FalconDecoderLayer {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
|
||||
let residual = x.clone();
|
||||
let ln_attn = self.inp_layernorm.forward(x)?;
|
||||
let attn_output = self.self_attention.forward(&ln_attn, mask)?;
|
||||
let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
|
||||
let (residual, ln_mlp) = match &self.post_attention_layernorm {
|
||||
None => (residual, ln_attn),
|
||||
Some(pal) => {
|
||||
@ -574,6 +598,10 @@ fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
impl Falcon {
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn load(vb: &VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let word_embeddings = Embedding::load(
|
||||
cfg.vocab_size,
|
||||
@ -603,9 +631,13 @@ impl Falcon {
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
||||
Some((k, _)) => k.dim(1)?,
|
||||
None => 0,
|
||||
};
|
||||
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
hidden_state = block.forward(&hidden_state, &causal_mask)?;
|
||||
hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
|
||||
}
|
||||
let hidden_state = self.ln_f.forward(&hidden_state)?;
|
||||
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
|
||||
|
Reference in New Issue
Block a user