mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the kv-cache to the whisper wasm version. (#689)
* Add the kv-cache to the whisper wasm version. * Improve the handling of special tokens.
This commit is contained in:
@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
pub const SOT_TOKEN: u32 = 50257;
|
||||
pub const EOT_TOKEN: u32 = 50256;
|
||||
pub const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
@ -71,11 +75,18 @@ pub struct Segment {
|
||||
pub dr: DecodingResult,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub struct Decoder {
|
||||
model: Whisper,
|
||||
mel_filters: Vec<f32>,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
transcribe_token: u32,
|
||||
translate_token: u32,
|
||||
eot_token: u32,
|
||||
no_speech_token: u32,
|
||||
no_timestamps_token: u32,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
@ -94,37 +105,49 @@ impl Decoder {
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
mel_filters,
|
||||
tokenizer,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
translate_token,
|
||||
eot_token,
|
||||
no_speech_token,
|
||||
no_timestamps_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
console_log!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![SOT_TOKEN];
|
||||
let mut tokens = vec![self.sot_token, self.transcribe_token];
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.get(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
@ -150,7 +173,7 @@ impl Decoder {
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
@ -169,7 +192,7 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn decode_with_fallback(
|
||||
&self,
|
||||
&mut self,
|
||||
segment: &Tensor,
|
||||
rng: &mut StdRng,
|
||||
) -> anyhow::Result<DecodingResult> {
|
||||
@ -195,7 +218,7 @@ impl Decoder {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let (_, _, content_frames) = mel.dims3()?;
|
||||
let mut seek = 0;
|
||||
@ -239,7 +262,7 @@ impl Decoder {
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||
let device = Device::Cpu;
|
||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||
let (header, data) = wav::read(&mut wav_input)?;
|
||||
@ -262,6 +285,13 @@ impl Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
||||
match tokenizer.token_to_id(token) {
|
||||
None => candle::bail!("no token-id for {token}"),
|
||||
Some(id) => Ok(id),
|
||||
}
|
||||
}
|
||||
|
||||
// Communication to the worker happens through bincode, the model weights and configs are fetched
|
||||
// on the main thread and transfered via the following structure.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
|
||||
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
|
||||
None => Err("model has not been set".to_string()),
|
||||
Some(decoder) => decoder
|
||||
.convert_and_run(&wav_bytes)
|
||||
|
Reference in New Issue
Block a user