Clean the decode loop of the whisper example.

@@ -2,7 +2,8 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
-// - language detection?
+// - Language detection?
+// - Batch size greater than 1.
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};

@@ -31,6 +32,10 @@ const LOGPROB_THRESHOLD: f64 = -1.0;
 const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
+// Tokenizer dependent bits.
+const SOT_TOKEN: u32 = 50257;
+const EOT_TOKEN: u32 = 50256;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
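
These ids are Whisper's special tokens: 50257 is `<|startoftranscript|>` and 50256 is `<|endoftext|>` in the multilingual vocabulary. A more principled variant, as the TODO removed in the next hunk suggests, would resolve them from the tokenizer at runtime. A minimal sketch, assuming the example's tokenizer is a `tokenizers::Tokenizer` (the helper name is hypothetical):

```rust
use anyhow::{anyhow, Result};
use tokenizers::Tokenizer;

/// Resolve Whisper's special token ids from the tokenizer instead of
/// hardcoding 50257/50256 (hypothetical helper, not part of the example).
fn special_token_ids(tokenizer: &Tokenizer) -> Result<(u32, u32)> {
    let sot_token = tokenizer
        .token_to_id("<|startoftranscript|>")
        .ok_or_else(|| anyhow!("no start-of-transcript token in vocab"))?;
    let eot_token = tokenizer
        .token_to_id("<|endoftext|>")
        .ok_or_else(|| anyhow!("no end-of-text token in vocab"))?;
    Ok((sot_token, eot_token))
}
```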

@@ -83,11 +88,12 @@ impl Decode {
         let sample_len = model.config.n_text_ctx / 2;
         let mut sum_logprob = 0f64;
         let no_speech_prob = f64::NAN;
-        // TODO: 50257 is the start of transcipt token, be more principled about get initial tokens
-        let mut tokens: Vec<u32> = vec![50257];
+        let mut tokens = vec![SOT_TOKEN];
         for _i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
-            // Insert a batch dim.
+            // The model expects a batch dim but this inference loop does not handle
+            // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
             let logits = model.decoder.forward(&tokens_t, &audio_features)?;
             let logits = logits.squeeze(0)?;
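
The reworded comment is about shape handling: the decoder expects a `(batch, seq_len)` tensor while the loop keeps a flat token vector, so the batch dimension is added before the forward pass and stripped from the logits afterwards. A self-contained sketch of that round trip with candle (the values are dummies; only the shapes matter):

```rust
use anyhow::Result;
use candle::{Device, Tensor};

fn main() -> Result<()> {
    // A stand-in for the token id vector tracked by the decode loop.
    let tokens_t = Tensor::new(&[0u32, 1, 2, 3], &Device::Cpu)?;
    assert_eq!(tokens_t.shape().dims(), &[4]);
    // unsqueeze(0) inserts a batch dimension of size 1 at the front,
    // giving the (batch, seq_len) layout the decoder expects.
    let batched = tokens_t.unsqueeze(0)?;
    assert_eq!(batched.shape().dims(), &[1, 4]);
    // squeeze(0) drops it again, as done on the logits after the forward pass.
    let unbatched = batched.squeeze(0)?;
    assert_eq!(unbatched.shape().dims(), &[4]);
    Ok(())
}
```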

@@ -112,11 +118,10 @@ impl Decode {
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            sum_logprob += prob.ln();
-            // 50256 is the eot token, TODO: parameterize this.
-            if next_token == 50256 || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
                 break;
             }
+            sum_logprob += prob.ln();
         }
         let text = self
             .tokenizer
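
Besides switching to the named `EOT_TOKEN`, the hunk moves the `sum_logprob` update below the termination check, so the end-of-text step no longer contributes to the total. That total later becomes the average log probability that the fallback logic compares against `LOGPROB_THRESHOLD`. A toy illustration of the metric (standalone sketch, values made up):

```rust
// Average per-token log probability, the quality signal the loop builds up.
// `probs` stands in for the softmax probability of each sampled token.
fn avg_logprob(probs: &[f64]) -> f64 {
    let sum_logprob: f64 = probs.iter().map(|p| p.ln()).sum();
    sum_logprob / probs.len() as f64
}

fn main() {
    // Three confident tokens: the average stays well above -1.0,
    // so a LOGPROB_THRESHOLD of -1.0 would not trigger a fallback.
    println!("{:.3}", avg_logprob(&[0.9, 0.8, 0.95])); // -0.127
}
```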

@@ -136,14 +141,17 @@ impl Decode {
 
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
         for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: DecodingResult = self.decode(segment, t)?;
+            let dr: Result<DecodingResult> = self.decode(segment, t);
             if i == TEMPERATURES.len() - 1 {
-                return Ok(dr);
+                return dr;
             }
-            let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                || dr.avg_logprob < LOGPROB_THRESHOLD;
-            if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
-                return Ok(dr);
+            // On errors, we try again with a different temperature.
+            if let Ok(dr) = dr {
+                let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+                    || dr.avg_logprob < LOGPROB_THRESHOLD;
+                if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    return Ok(dr);
+                }
             }
         }
         unreachable!()
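
The rewritten loop distinguishes three outcomes per temperature: a hard decode error (retry), a low-quality result (retry), and an acceptable result (return); only the last temperature is returned unconditionally, which is why the trailing `unreachable!()` is safe. A self-contained sketch of the same control flow with a stubbed `decode` (a single `quality` field stands in for the compression-ratio, logprob, and no-speech checks):

```rust
use anyhow::{anyhow, Result};

const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];

struct Decoded {
    quality: f64,
}

// Hypothetical decode step: fails outright at low temperatures to
// exercise the error branch.
fn decode(t: f64) -> Result<Decoded> {
    if t < 0.4 {
        Err(anyhow!("decode failed at temperature {t}"))
    } else {
        Ok(Decoded { quality: t })
    }
}

fn decode_with_fallback() -> Result<Decoded> {
    for (i, &t) in TEMPERATURES.iter().enumerate() {
        let dr = decode(t);
        // The last temperature is returned as-is, error or not.
        if i == TEMPERATURES.len() - 1 {
            return dr;
        }
        // On errors, fall through and retry at the next temperature.
        if let Ok(dr) = dr {
            if dr.quality > 0.5 {
                return Ok(dr);
            }
        }
    }
    unreachable!()
}

fn main() -> Result<()> {
    // Fails at 0.0 and 0.2, decodes poorly at 0.4, succeeds at 0.6.
    let dr = decode_with_fallback()?;
    println!("kept result with quality {}", dr.quality);
    Ok(())
}
```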

@@ -195,7 +203,7 @@ fn main() -> Result<()> {
             duration: segment_duration,
             dr,
         };
-        println!("{seek} {segment:?}");
+        println!("{seek}: {segment:?}");
         segments.push(segment)
     }
     Ok(())

A second file in the commit gains a comment documenting the error-handling choice:

@@ -1,3 +1,5 @@
+// We use anyhow rather than candle errors as it provides better support for getting the backtrace
+// back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
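
For context on the new comment: an `anyhow::Error` captures a backtrace at the point the underlying error is first converted, and setting `RUST_LIB_BACKTRACE=1` turns that capture on without also enabling panic backtraces. A minimal illustration, not tied to the example's actual file layout (path hypothetical):

```rust
use anyhow::{Context, Result};

fn load_weights() -> Result<Vec<u8>> {
    // On failure, the anyhow error carries a backtrace pointing here
    // when RUST_LIB_BACKTRACE=1 (or RUST_BACKTRACE=1) is set.
    std::fs::read("weights.safetensors").context("failed to read weights.safetensors")
}

fn main() -> Result<()> {
    let bytes = load_weights()?;
    println!("read {} bytes", bytes.len());
    Ok(())
}
```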