mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add the kv-cache to the whisper wasm version. (#689)
* Add the kv-cache to the whisper wasm version. * Improve the handling of special tokens.
This commit is contained in:
@ -27,7 +27,7 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub fn decode(&self, wav_input: Vec<u8>) -> Result<String, JsError> {
|
pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
|
||||||
let segments = self
|
let segments = self
|
||||||
.decoder
|
.decoder
|
||||||
.convert_and_run(&wav_input)
|
.convert_and_run(&wav_input)
|
||||||
|
@ -109,6 +109,7 @@ struct MultiHeadAttention {
|
|||||||
value: Linear,
|
value: Linear,
|
||||||
out: Linear,
|
out: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadAttention {
|
impl MultiHeadAttention {
|
||||||
@ -123,14 +124,39 @@ impl MultiHeadAttention {
|
|||||||
value,
|
value,
|
||||||
out,
|
out,
|
||||||
n_head,
|
n_head,
|
||||||
|
kv_cache: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let _timer = crate::Timer::new("MultiHeadAttention::forward");
|
let _timer = crate::Timer::new("MultiHeadAttention::forward");
|
||||||
let q = self.query.forward(x)?;
|
let q = self.query.forward(x)?;
|
||||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
let (k, v) = match xa {
|
||||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
None => {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
Some(x) => {
|
||||||
|
if flush_cache {
|
||||||
|
self.kv_cache = None;
|
||||||
|
}
|
||||||
|
if let Some((k, v)) = &self.kv_cache {
|
||||||
|
(k.clone(), v.clone())
|
||||||
|
} else {
|
||||||
|
let k = self.key.forward(x)?;
|
||||||
|
let v = self.value.forward(x)?;
|
||||||
|
self.kv_cache = Some((k.clone(), v.clone()));
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||||
let out = self.out.forward(&wv)?;
|
let out = self.out.forward(&wv)?;
|
||||||
Ok(out)
|
Ok(out)
|
||||||
@ -151,18 +177,9 @@ impl MultiHeadAttention {
|
|||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let (_, n_ctx, n_state) = q.dims3()?;
|
let (_, n_ctx, n_state) = q.dims3()?;
|
||||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||||
let q = {
|
let q = (self.reshape_head(q)? * scale)?;
|
||||||
let _timer = crate::Timer::new("q::reshape");
|
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||||
(self.reshape_head(q)? * scale)?
|
let v = self.reshape_head(v)?.contiguous()?;
|
||||||
};
|
|
||||||
let k = {
|
|
||||||
let _timer = crate::Timer::new("k::reshape");
|
|
||||||
(self.reshape_head(k)?.transpose(2, 3)? * scale)?
|
|
||||||
};
|
|
||||||
let v = {
|
|
||||||
let _timer = crate::Timer::new("v::reshape-contiguous");
|
|
||||||
self.reshape_head(v)?.contiguous()?
|
|
||||||
};
|
|
||||||
let mut qk = {
|
let mut qk = {
|
||||||
let _timer = crate::Timer::new("qk::matmul");
|
let _timer = crate::Timer::new("qk::matmul");
|
||||||
q.matmul(&k)?
|
q.matmul(&k)?
|
||||||
@ -218,12 +235,20 @@ impl ResidualAttentionBlock {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
xa: Option<&Tensor>,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
flush_kv_cache: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
|
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
|
||||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
let attn = self
|
||||||
|
.attn
|
||||||
|
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||||
let mut x = (x + attn)?;
|
let mut x = (x + attn)?;
|
||||||
if let Some((attn, ln)) = &self.cross_attn {
|
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
|
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||||
}
|
}
|
||||||
let mlp = self.mlp_linear2.forward(
|
let mlp = self.mlp_linear2.forward(
|
||||||
&self
|
&self
|
||||||
@ -294,7 +319,7 @@ impl AudioEncoder {
|
|||||||
ln_post,
|
ln_post,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
let _timer = crate::Timer::new("AudioEncoder::forward");
|
let _timer = crate::Timer::new("AudioEncoder::forward");
|
||||||
let x = {
|
let x = {
|
||||||
let _timer = crate::Timer::new("conv1::forward");
|
let _timer = crate::Timer::new("conv1::forward");
|
||||||
@ -308,8 +333,8 @@ impl AudioEncoder {
|
|||||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter_mut() {
|
||||||
x = block.forward(&x, None, None)?
|
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||||
}
|
}
|
||||||
let x = self.ln_post.forward(&x)?;
|
let x = self.ln_post.forward(&x)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
@ -353,14 +378,14 @@ impl TextDecoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
let x_dims = x.dims();
|
let x_dims = x.dims();
|
||||||
let last = x_dims[x_dims.len() - 1];
|
let last = x_dims[x_dims.len() - 1];
|
||||||
let token_embedding = self.token_embedding.forward(x)?;
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter_mut() {
|
||||||
x = block.forward(&x, Some(xa), Some(&self.mask))?;
|
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||||
}
|
}
|
||||||
let x = self.ln.forward(&x)?;
|
let x = self.ln.forward(&x)?;
|
||||||
let w = self
|
let w = self
|
||||||
|
@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
|||||||
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||||
|
|
||||||
// Tokenizer dependent bits.
|
// Tokenizer dependent bits.
|
||||||
pub const SOT_TOKEN: u32 = 50257;
|
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||||
pub const EOT_TOKEN: u32 = 50256;
|
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||||
pub const NO_SPEECH_TOKEN: u32 = 50361;
|
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||||
|
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||||
|
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||||
|
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||||
|
|
||||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||||
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||||
@ -71,11 +75,18 @@ pub struct Segment {
|
|||||||
pub dr: DecodingResult,
|
pub dr: DecodingResult,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
pub struct Decoder {
|
pub struct Decoder {
|
||||||
model: Whisper,
|
model: Whisper,
|
||||||
mel_filters: Vec<f32>,
|
mel_filters: Vec<f32>,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
suppress_tokens: Tensor,
|
suppress_tokens: Tensor,
|
||||||
|
sot_token: u32,
|
||||||
|
transcribe_token: u32,
|
||||||
|
translate_token: u32,
|
||||||
|
eot_token: u32,
|
||||||
|
no_speech_token: u32,
|
||||||
|
no_timestamps_token: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Decoder {
|
impl Decoder {
|
||||||
@ -94,37 +105,49 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||||
|
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||||
|
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||||
|
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||||
|
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||||
|
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
mel_filters,
|
mel_filters,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
suppress_tokens,
|
suppress_tokens,
|
||||||
|
sot_token,
|
||||||
|
transcribe_token,
|
||||||
|
translate_token,
|
||||||
|
eot_token,
|
||||||
|
no_speech_token,
|
||||||
|
no_timestamps_token,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||||
let model = &self.model;
|
let model = &mut self.model;
|
||||||
let audio_features = model.encoder.forward(mel)?;
|
let audio_features = model.encoder.forward(mel, true)?;
|
||||||
console_log!("audio features: {:?}", audio_features.dims());
|
console_log!("audio features: {:?}", audio_features.dims());
|
||||||
let sample_len = model.config.max_target_positions / 2;
|
let sample_len = model.config.max_target_positions / 2;
|
||||||
let mut sum_logprob = 0f64;
|
let mut sum_logprob = 0f64;
|
||||||
let mut no_speech_prob = f64::NAN;
|
let mut no_speech_prob = f64::NAN;
|
||||||
let mut tokens = vec![SOT_TOKEN];
|
let mut tokens = vec![self.sot_token, self.transcribe_token];
|
||||||
for i in 0..sample_len {
|
for i in 0..sample_len {
|
||||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||||
|
|
||||||
// The model expects a batch dim but this inference loop does not handle
|
// The model expects a batch dim but this inference loop does not handle
|
||||||
// it so we add it at this point.
|
// it so we add it at this point.
|
||||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
|
|
||||||
// Extract the no speech probability on the first iteration by looking at the first
|
// Extract the no speech probability on the first iteration by looking at the first
|
||||||
// token logits and the probability for the corresponding token.
|
// token logits and the probability for the corresponding token.
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||||
.get(NO_SPEECH_TOKEN as usize)?
|
.get(self.no_speech_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +173,7 @@ impl Decoder {
|
|||||||
let prob = softmax(&logits, candle::D::Minus1)?
|
let prob = softmax(&logits, candle::D::Minus1)?
|
||||||
.get(next_token as usize)?
|
.get(next_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sum_logprob += prob.ln();
|
sum_logprob += prob.ln();
|
||||||
@ -169,7 +192,7 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decode_with_fallback(
|
fn decode_with_fallback(
|
||||||
&self,
|
&mut self,
|
||||||
segment: &Tensor,
|
segment: &Tensor,
|
||||||
rng: &mut StdRng,
|
rng: &mut StdRng,
|
||||||
) -> anyhow::Result<DecodingResult> {
|
) -> anyhow::Result<DecodingResult> {
|
||||||
@ -195,7 +218,7 @@ impl Decoder {
|
|||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||||
let mut rng = StdRng::seed_from_u64(299792458);
|
let mut rng = StdRng::seed_from_u64(299792458);
|
||||||
let (_, _, content_frames) = mel.dims3()?;
|
let (_, _, content_frames) = mel.dims3()?;
|
||||||
let mut seek = 0;
|
let mut seek = 0;
|
||||||
@ -239,7 +262,7 @@ impl Decoder {
|
|||||||
Ok(decoder)
|
Ok(decoder)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||||
let (header, data) = wav::read(&mut wav_input)?;
|
let (header, data) = wav::read(&mut wav_input)?;
|
||||||
@ -262,6 +285,13 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Looks up the vocabulary id of `token` in `tokenizer`.
///
/// # Errors
///
/// Returns an error carrying the message `no token-id for {token}` when
/// `tokenizer.token_to_id` yields `None` for the requested token.
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    // Happy path first: the tokenizer knows this token.
    if let Some(id) = tokenizer.token_to_id(token) {
        return Ok(id);
    }
    candle::bail!("no token-id for {token}")
}
|
||||||
|
|
||||||
// Communication to the worker happens through bincode, the model weights and configs are fetched
|
// Communication to the worker happens through bincode, the model weights and configs are fetched
|
||||||
// on the main thread and transferred via the following structure.
|
// on the main thread and transferred via the following structure.
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
|
|||||||
}
|
}
|
||||||
Err(err) => Err(format!("model creation error {err:?}")),
|
Err(err) => Err(format!("model creation error {err:?}")),
|
||||||
},
|
},
|
||||||
WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
|
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
|
||||||
None => Err("model has not been set".to_string()),
|
None => Err("model has not been set".to_string()),
|
||||||
Some(decoder) => decoder
|
Some(decoder) => decoder
|
||||||
.convert_and_run(&wav_bytes)
|
.convert_and_run(&wav_bytes)
|
||||||
|
Reference in New Issue
Block a user