From 9fe7a428954e774f760830a063376e9fcf228e13 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 4 Jul 2023 22:18:07 +0100 Subject: [PATCH] More whisper sampling. --- candle-examples/examples/whisper/main.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 45ea74a0..647ac6f9 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -619,13 +619,20 @@ impl Decode { fn decode(&mut self, mel: &Tensor, t: f64) -> Result { let model = &self.model; let audio_features = model.encoder.forward(mel)?; + println!("audio features: {:?}", audio_features.dims()); let sample_len = model.config.n_text_ctx / 2; let mut sum_logprob = 0f64; let no_speech_prob = f64::NAN; - let mut tokens: Vec = vec![]; // TODO: get initial tokens + // TODO: 50257 is the start of transcipt token, be more principled about get initial tokens + let mut tokens: Vec = vec![50257]; for _i in 0..sample_len { let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?; + // Insert a batch dim. + let tokens_t = tokens_t.unsqueeze(0)?; let logits = model.decoder.forward(&tokens_t, &audio_features)?; + let logits = logits.squeeze(0)?; + let (seq_len, _) = logits.shape().r2()?; + let logits = logits.get(seq_len - 1)?; let next_token = if t > 0f64 { let prs = (&logits / t)?.softmax(logits.rank() - 1)?; let logits_v: Vec = prs.to_vec1()?; @@ -698,7 +705,6 @@ fn main() -> Result<()> { let input = input.deserialize()?; let mel = input.tensor("mel", &device)?; println!("loaded mel: {:?}", mel.dims()); - let mel = if mel.rank() > 2 { mel.squeeze(0)? } else { mel }; let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; let weights = weights.deserialize()?; @@ -710,14 +716,13 @@ fn main() -> Result<()> { tokenizer, }; - let (_, content_frames) = mel.shape().r2()?; - let content_frames = content_frames - N_SAMPLES; + let (_, _, content_frames) = mel.shape().r3()?; let mut seek = 0; let mut segments = vec![]; while seek < content_frames { let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let segment_size = usize::min(content_frames - seek, N_FRAMES); - let mel_segment = mel.narrow(1, seek, segment_size)?; + let mel_segment = mel.narrow(2, seek, segment_size)?; let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let dr = dc.decode_with_fallback(&mel_segment)?; seek += segment_size;