Add the kv-cache to the whisper wasm version. (#689)

* Add the kv-cache to the whisper wasm version.

* Improve the handling of special tokens.
This commit is contained in:
Laurent Mazare
2023-08-31 10:37:44 +02:00
committed by GitHub
parent db59816087
commit 94aa234dfd
3 changed files with 95 additions and 40 deletions

View File

@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
pub const SOT_TOKEN: u32 = 50257;
pub const EOT_TOKEN: u32 = 50256;
pub const NO_SPEECH_TOKEN: u32 = 50361;
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
pub const SUPPRESS_TOKENS: [u32; 91] = [
@ -71,11 +75,18 @@ pub struct Segment {
pub dr: DecodingResult,
}
#[allow(unused)]
pub struct Decoder {
model: Whisper,
mel_filters: Vec<f32>,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
}
impl Decoder {
@ -94,37 +105,49 @@ impl Decoder {
}
})
.collect();
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
model,
mel_filters,
tokenizer,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
no_timestamps_token,
})
}
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
console_log!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![SOT_TOKEN];
let mut tokens = vec![self.sot_token, self.transcribe_token];
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let logits = logits.squeeze(0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.get(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
@ -150,7 +173,7 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
break;
}
sum_logprob += prob.ln();
@ -169,7 +192,7 @@ impl Decoder {
}
fn decode_with_fallback(
&self,
&mut self,
segment: &Tensor,
rng: &mut StdRng,
) -> anyhow::Result<DecodingResult> {
@ -195,7 +218,7 @@ impl Decoder {
unreachable!()
}
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let mut rng = StdRng::seed_from_u64(299792458);
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
@ -239,7 +262,7 @@ impl Decoder {
Ok(decoder)
}
pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
let device = Device::Cpu;
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
@ -262,6 +285,13 @@ impl Decoder {
}
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
None => Err("model has not been set".to_string()),
Some(decoder) => decoder
.convert_and_run(&wav_bytes)