mirror of https://github.com/huggingface/candle.git
Clean the decode loop of the whisper example.
@@ -2,7 +2,8 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
-// - language detection?
+// - Language detection?
+// - Batch size greater than 1.
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
@@ -31,6 +32,10 @@ const LOGPROB_THRESHOLD: f64 = -1.0;
 const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
+// Tokenizer dependent bits.
+const SOT_TOKEN: u32 = 50257;
+const EOT_TOKEN: u32 = 50256;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
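
The two ids above are tokenizer dependent, as the new comment says: 50257 and 50256 are the start-of-transcript and end-of-text ids of the vocabulary this example ships with. A possible sketch of looking them up via the `tokenizers` crate (already a dependency of the example) instead of hard-coding them; the exact token strings are an assumption about the vocabulary in use:

    use anyhow::{anyhow, Result};
    use tokenizers::Tokenizer;

    // Hypothetical helper, not part of the commit: resolve the special token
    // ids at runtime instead of relying on hard-coded constants.
    fn special_tokens(tokenizer: &Tokenizer) -> Result<(u32, u32)> {
        let sot = tokenizer
            .token_to_id("<|startoftranscript|>")
            .ok_or_else(|| anyhow!("missing <|startoftranscript|>"))?;
        let eot = tokenizer
            .token_to_id("<|endoftext|>")
            .ok_or_else(|| anyhow!("missing <|endoftext|>"))?;
        Ok((sot, eot))
    }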
@@ -83,11 +88,12 @@ impl Decode {
         let sample_len = model.config.n_text_ctx / 2;
         let mut sum_logprob = 0f64;
         let no_speech_prob = f64::NAN;
-        // TODO: 50257 is the start of transcipt token, be more principled about get initial tokens
-        let mut tokens: Vec<u32> = vec![50257];
+        let mut tokens = vec![SOT_TOKEN];
         for _i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
-            // Insert a batch dim.
+
+            // The model expects a batch dim but this inference loop does not handle
+            // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
             let logits = model.decoder.forward(&tokens_t, &audio_features)?;
             let logits = logits.squeeze(0)?;
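
The replacement comment explains the `unsqueeze(0)` that follows it: the decoder wants a leading batch dimension, and this single-sequence loop adds it just before the forward pass (then drops it again with `squeeze(0)` on the logits). A minimal, self-contained shape round-trip with candle tensors, assuming the CPU device and dummy token ids:

    use candle::{Device, Tensor};

    fn main() -> anyhow::Result<()> {
        let tokens: Vec<u32> = vec![50257, 123, 456];
        let t = Tensor::new(tokens.as_slice(), &Device::Cpu)?; // shape [3]
        let batched = t.unsqueeze(0)?; // insert a batch dim: shape [1, 3]
        assert_eq!(batched.dims(), &[1, 3]);
        let squeezed = batched.squeeze(0)?; // drop it again: shape [3]
        assert_eq!(squeezed.dims(), &[3]);
        Ok(())
    }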
@@ -112,11 +118,10 @@
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            sum_logprob += prob.ln();
-            // 50256 is the eot token, TODO: parameterize this.
-            if next_token == 50256 || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
                 break;
             }
+            sum_logprob += prob.ln();
         }
         let text = self
             .tokenizer
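
The reordering here is the substantive fix: `sum_logprob` is now only updated after the end-of-transcript check, so the terminating token's probability no longer skews the total. That sum feeds the average log-probability that the fallback logic compares against LOGPROB_THRESHOLD; a small sketch of that derived quantity (the helper and its input are hypothetical, not names from the example):

    // `probs` stands in for the per-token probabilities the decode loop
    // collects; the average log-prob is what gets thresholded at -1.0.
    fn avg_logprob(probs: &[f64]) -> f64 {
        let sum_logprob: f64 = probs.iter().map(|p| p.ln()).sum();
        sum_logprob / probs.len() as f64
    }

    fn main() {
        // Three fairly confident tokens give an average well above -1.0.
        println!("{}", avg_logprob(&[0.9, 0.8, 0.95]));
    }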
@@ -136,16 +141,19 @@ impl Decode {
 
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
         for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: DecodingResult = self.decode(segment, t)?;
+            let dr: Result<DecodingResult> = self.decode(segment, t);
             if i == TEMPERATURES.len() - 1 {
-                return Ok(dr);
+                return dr;
             }
-            let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                || dr.avg_logprob < LOGPROB_THRESHOLD;
-            if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
-                return Ok(dr);
-            }
+            // On errors, we try again with a different temperature.
+            if let Ok(dr) = dr {
+                let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+                    || dr.avg_logprob < LOGPROB_THRESHOLD;
+                if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    return Ok(dr);
+                }
+            }
         }
         unreachable!()
     }
 }
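
After this rewrite, a decoding error no longer aborts the whole run via `?`: each temperature's `Result` is inspected, failures and low-quality results both trigger a retry, and only the last temperature is returned unconditionally. The same control flow in isolation, with `try_decode` and `good_enough` as hypothetical stand-ins for `self.decode` and the threshold checks:

    const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];

    fn decode_with_fallback<T>(
        mut try_decode: impl FnMut(f64) -> anyhow::Result<T>,
        good_enough: impl Fn(&T) -> bool,
    ) -> anyhow::Result<T> {
        for (i, &t) in TEMPERATURES.iter().enumerate() {
            let dr = try_decode(t);
            // On the last temperature, return whatever we got, error included.
            if i == TEMPERATURES.len() - 1 {
                return dr;
            }
            // On errors or low-quality output, retry at a higher temperature.
            if let Ok(dr) = dr {
                if good_enough(&dr) {
                    return Ok(dr);
                }
            }
        }
        // The loop always returns on its final iteration; the compiler just
        // cannot see that, hence the same `unreachable!()` as the example.
        unreachable!()
    }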
@@ -195,7 +203,7 @@ fn main() -> Result<()> {
             duration: segment_duration,
             dr,
         };
-        println!("{seek} {segment:?}");
+        println!("{seek}: {segment:?}");
         segments.push(segment)
     }
     Ok(())
@@ -1,3 +1,5 @@
+// We use anyhow rather than candle errors as it provides better support for getting the backtrace
+// back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
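
This last hunk, from a second file in the commit, documents the error-handling choice. A minimal illustration of what anyhow buys here, assuming a fallible weight load wrapped with context (the helper name and file name are only examples):

    use anyhow::{Context, Result};

    fn load_weights(path: &str) -> Result<Vec<u8>> {
        // `with_context` attaches a message; with RUST_LIB_BACKTRACE=1 (or
        // RUST_BACKTRACE=1) set, anyhow also captures a backtrace at the
        // point where the underlying io::Error is converted.
        std::fs::read(path).with_context(|| format!("failed to read {path}"))
    }

    fn main() -> Result<()> {
        let _bytes = load_weights("model.safetensors")?;
        Ok(())
    }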