mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add a KV cache to whisper. (#426)
This commit is contained in:
@ -109,8 +109,8 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||||
let model = &self.model;
|
let model = &mut self.model;
|
||||||
let audio_features = model.encoder.forward(mel)?;
|
let audio_features = model.encoder.forward(mel, true)?;
|
||||||
println!("audio features: {:?}", audio_features.dims());
|
println!("audio features: {:?}", audio_features.dims());
|
||||||
let sample_len = model.config.max_target_positions / 2;
|
let sample_len = model.config.max_target_positions / 2;
|
||||||
let mut sum_logprob = 0f64;
|
let mut sum_logprob = 0f64;
|
||||||
@ -126,7 +126,7 @@ impl Decoder {
|
|||||||
// The model expects a batch dim but this inference loop does not handle
|
// The model expects a batch dim but this inference loop does not handle
|
||||||
// it so we add it at this point.
|
// it so we add it at this point.
|
||||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
|
|
||||||
// Extract the no speech probability on the first iteration by looking at the first
|
// Extract the no speech probability on the first iteration by looking at the first
|
||||||
@ -393,10 +393,10 @@ fn main() -> Result<()> {
|
|||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||||
let model = Whisper::load(&vb, config)?;
|
let mut model = Whisper::load(&vb, config)?;
|
||||||
|
|
||||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||||
(true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
|
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||||
(false, None) => None,
|
(false, None) => None,
|
||||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||||
Ok(token_id) => Some(token_id),
|
Ok(token_id) => Some(token_id),
|
||||||
|
@ -105,6 +105,7 @@ struct MultiHeadAttention {
|
|||||||
out: Linear,
|
out: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadAttention {
|
impl MultiHeadAttention {
|
||||||
@ -121,14 +122,39 @@ impl MultiHeadAttention {
|
|||||||
out,
|
out,
|
||||||
n_head,
|
n_head,
|
||||||
span,
|
span,
|
||||||
|
kv_cache: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let q = self.query.forward(x)?;
|
let q = self.query.forward(x)?;
|
||||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
let (k, v) = match xa {
|
||||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
None => {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
Some(x) => {
|
||||||
|
if flush_cache {
|
||||||
|
self.kv_cache = None;
|
||||||
|
}
|
||||||
|
if let Some((k, v)) = &self.kv_cache {
|
||||||
|
(k.clone(), v.clone())
|
||||||
|
} else {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
self.kv_cache = Some((k.clone(), v.clone()));
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||||
let out = self.out.forward(&wv)?;
|
let out = self.out.forward(&wv)?;
|
||||||
Ok(out)
|
Ok(out)
|
||||||
@ -201,12 +227,20 @@ impl ResidualAttentionBlock {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_kv_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
let attn = self
|
||||||
|
.attn
|
||||||
|
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||||
let mut x = (x + attn)?;
|
let mut x = (x + attn)?;
|
||||||
if let Some((attn, ln)) = &self.cross_attn {
|
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
|
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||||
}
|
}
|
||||||
let mlp = self.mlp_linear2.forward(
|
let mlp = self.mlp_linear2.forward(
|
||||||
&self
|
&self
|
||||||
@ -283,7 +317,7 @@ impl AudioEncoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let x = {
|
let x = {
|
||||||
let _enter = self.conv1_span.enter();
|
let _enter = self.conv1_span.enter();
|
||||||
@ -297,8 +331,8 @@ impl AudioEncoder {
|
|||||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter_mut() {
|
||||||
x = block.forward(&x, None, None)?
|
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||||
}
|
}
|
||||||
let x = self.ln_post.forward(&x)?;
|
let x = self.ln_post.forward(&x)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
@ -344,15 +378,15 @@ impl TextDecoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let x_dims = x.dims();
|
let x_dims = x.dims();
|
||||||
let last = x_dims[x_dims.len() - 1];
|
let last = x_dims[x_dims.len() - 1];
|
||||||
let token_embedding = self.token_embedding.forward(x)?;
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter_mut() {
|
||||||
x = block.forward(&x, Some(xa), Some(&self.mask))?;
|
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||||
}
|
}
|
||||||
let x = self.ln.forward(&x)?;
|
let x = self.ln.forward(&x)?;
|
||||||
let w = self
|
let w = self
|
||||||
@ -383,9 +417,14 @@ impl Whisper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
pub fn forward(
|
||||||
let enc = self.encoder.forward(mel)?;
|
&mut self,
|
||||||
let dec = self.decoder.forward(tokens, &enc)?;
|
mel: &Tensor,
|
||||||
|
tokens: &Tensor,
|
||||||
|
flush_kv_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let enc = self.encoder.forward(mel, flush_kv_cache)?;
|
||||||
|
let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
|
||||||
Ok(dec)
|
Ok(dec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,19 +105,19 @@ const LANGUAGES: [(&str, &str); 99] = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
/// Returns the token id for the selected language.
|
/// Returns the token id for the selected language.
|
||||||
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||||
let device = mel.device();
|
let device = mel.device();
|
||||||
let language_token_ids = LANGUAGES
|
let language_token_ids = LANGUAGES
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
||||||
let audio_features = model.encoder.forward(mel)?;
|
let audio_features = model.encoder.forward(mel, true)?;
|
||||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||||
let logits = model
|
let logits = model
|
||||||
.decoder
|
.decoder
|
||||||
.forward(&tokens, &audio_features)?
|
.forward(&tokens, &audio_features, true)?
|
||||||
.i(0)?
|
.i(0)?
|
||||||
.i(0)?;
|
.i(0)?;
|
||||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||||
|
Reference in New Issue
Block a user