mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Whisper tweaks (#85)
* Isolate the decoding bits of the whisper example. * Decode -> Decoder. * Add the suppress tokens filter. * More suppress tokens.
This commit is contained in:
@ -4,6 +4,7 @@
|
||||
// - kv-cache support?
|
||||
// - Language detection?
|
||||
// - Batch size greater than 1.
|
||||
// - More token filters (SuppressBlanks, ApplyTimestampRules).
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
@ -40,6 +41,16 @@ const SOT_TOKEN: u32 = 50257;
|
||||
const EOT_TOKEN: u32 = 50256;
|
||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
|
||||
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
|
||||
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
|
||||
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
|
||||
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
|
||||
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecodingResult {
|
||||
@ -58,13 +69,33 @@ struct Segment {
|
||||
dr: DecodingResult,
|
||||
}
|
||||
|
||||
struct Decode {
|
||||
struct Decoder {
|
||||
model: Whisper,
|
||||
rng: rand::rngs::StdRng,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
}
|
||||
|
||||
impl Decode {
|
||||
impl Decoder {
|
||||
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
.map(|i| {
|
||||
if SUPPRESS_TOKENS.contains(&i) {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0f32
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
tokenizer,
|
||||
suppress_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
@ -93,7 +124,9 @@ impl Decode {
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.shape().r2()?;
|
||||
let logits = logits.get(seq_len - 1)?;
|
||||
let logits = logits
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
@ -156,6 +189,33 @@ impl Decode {
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
let start = std::time::Instant::now();
|
||||
while seek < content_frames {
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||
println!("no speech detected, skipping {seek} {dr:?}");
|
||||
continue;
|
||||
}
|
||||
let segment = Segment {
|
||||
start: time_offset,
|
||||
duration: segment_duration,
|
||||
dr,
|
||||
};
|
||||
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
||||
segments.push(segment)
|
||||
}
|
||||
Ok(segments)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -168,11 +228,13 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?search=whisper
|
||||
/// The model to use, check out available models:
|
||||
/// https://huggingface.co/models?search=whisper
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
|
||||
/// The input to be processed, in wav formats, will default to `jfk.wav`
|
||||
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
|
||||
#[arg(long)]
|
||||
input: Option<String>,
|
||||
|
||||
@ -196,8 +258,6 @@ async fn main() -> Result<()> {
|
||||
} else {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
||||
|
||||
let default_model = "openai/whisper-tiny.en".to_string();
|
||||
let path = std::path::PathBuf::from(default_model.clone());
|
||||
let default_revision = "refs/pr/15".to_string();
|
||||
@ -267,37 +327,10 @@ async fn main() -> Result<()> {
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let model = Whisper::load(&vb, config)?;
|
||||
let mut dc = Decode {
|
||||
model,
|
||||
rng,
|
||||
tokenizer,
|
||||
};
|
||||
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
let start = std::time::Instant::now();
|
||||
while seek < content_frames {
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let dr = dc.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||
println!("no speech detected, skipping {seek} {dr:?}");
|
||||
continue;
|
||||
}
|
||||
let segment = Segment {
|
||||
start: time_offset,
|
||||
duration: segment_duration,
|
||||
dr,
|
||||
};
|
||||
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
||||
segments.push(segment)
|
||||
}
|
||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
|
||||
dc.run(&mel)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ impl<'a> VarBuilder<'a> {
|
||||
pub fn from_safetensors(
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let mut routing = HashMap::new();
|
||||
for (index, sf) in safetensors.iter().enumerate() {
|
||||
@ -25,7 +25,7 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
Self {
|
||||
safetensors: Some((routing, safetensors)),
|
||||
device,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user