mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Whisper tweaks (#85)
* Isolate the decoding bits of the whisper example.
* Decode -> Decoder.
* Add the suppress tokens filter.
* More suppress tokens.
This commit is contained in:
@ -4,6 +4,7 @@
|
|||||||
// - kv-cache support?
|
// - kv-cache support?
|
||||||
// - Language detection?
|
// - Language detection?
|
||||||
// - Batch size greater than 1.
|
// - Batch size greater than 1.
|
||||||
|
// - More token filters (SuppressBlanks, ApplyTimestampRules).
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
@ -40,6 +41,16 @@ const SOT_TOKEN: u32 = 50257;
|
|||||||
const EOT_TOKEN: u32 = 50256;
|
const EOT_TOKEN: u32 = 50256;
|
||||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||||
|
// Token ids whose logits get masked out while decoding.
//
// The list mirrors what `_get_suppress_tokens` produces in the reference implementation, with
// 50362 (the no-timestamp token) appended at the end:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
    47282, 49146,
    // This tail includes SOT_TOKEN (50257), NO_SPEECH_TOKEN (50361) and
    // NO_TIMESTAMP_TOKEN (50362).
    50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecodingResult {
|
struct DecodingResult {
|
||||||
@ -58,13 +69,33 @@ struct Segment {
|
|||||||
dr: DecodingResult,
|
dr: DecodingResult,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Decode {
|
struct Decoder {
|
||||||
model: Whisper,
|
model: Whisper,
|
||||||
rng: rand::rngs::StdRng,
|
rng: rand::rngs::StdRng,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
suppress_tokens: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Decode {
|
impl Decoder {
|
||||||
|
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
|
||||||
|
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||||
|
.map(|i| {
|
||||||
|
if SUPPRESS_TOKENS.contains(&i) {
|
||||||
|
f32::NEG_INFINITY
|
||||||
|
} else {
|
||||||
|
0f32
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||||
|
Ok(Self {
|
||||||
|
model,
|
||||||
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
|
tokenizer,
|
||||||
|
suppress_tokens,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||||
let model = &self.model;
|
let model = &self.model;
|
||||||
let audio_features = model.encoder.forward(mel)?;
|
let audio_features = model.encoder.forward(mel)?;
|
||||||
@ -93,7 +124,9 @@ impl Decode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (seq_len, _) = logits.shape().r2()?;
|
let (seq_len, _) = logits.shape().r2()?;
|
||||||
let logits = logits.get(seq_len - 1)?;
|
let logits = logits
|
||||||
|
.get(seq_len - 1)?
|
||||||
|
.broadcast_add(&self.suppress_tokens)?;
|
||||||
let next_token = if t > 0f64 {
|
let next_token = if t > 0f64 {
|
||||||
let prs = (&logits / t)?.softmax(0)?;
|
let prs = (&logits / t)?.softmax(0)?;
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
@ -156,6 +189,33 @@ impl Decode {
|
|||||||
}
|
}
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
|
||||||
|
let (_, _, content_frames) = mel.shape().r3()?;
|
||||||
|
let mut seek = 0;
|
||||||
|
let mut segments = vec![];
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
while seek < content_frames {
|
||||||
|
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||||
|
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||||
|
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||||
|
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||||
|
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||||
|
seek += segment_size;
|
||||||
|
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||||
|
println!("no speech detected, skipping {seek} {dr:?}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let segment = Segment {
|
||||||
|
start: time_offset,
|
||||||
|
duration: segment_duration,
|
||||||
|
dr,
|
||||||
|
};
|
||||||
|
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
||||||
|
segments.push(segment)
|
||||||
|
}
|
||||||
|
Ok(segments)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -168,11 +228,13 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?search=whisper
|
/// The model to use, check out available models:
|
||||||
|
/// https://huggingface.co/models?search=whisper
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
/// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
|
/// The input to be processed, in wav formats, will default to `jfk.wav`
|
||||||
|
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
input: Option<String>,
|
input: Option<String>,
|
||||||
|
|
||||||
@ -196,8 +258,6 @@ async fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
|
||||||
|
|
||||||
let default_model = "openai/whisper-tiny.en".to_string();
|
let default_model = "openai/whisper-tiny.en".to_string();
|
||||||
let path = std::path::PathBuf::from(default_model.clone());
|
let path = std::path::PathBuf::from(default_model.clone());
|
||||||
let default_revision = "refs/pr/15".to_string();
|
let default_revision = "refs/pr/15".to_string();
|
||||||
@ -267,37 +327,10 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||||
let model = Whisper::load(&vb, config)?;
|
let model = Whisper::load(&vb, config)?;
|
||||||
let mut dc = Decode {
|
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
|
||||||
model,
|
dc.run(&mel)?;
|
||||||
rng,
|
|
||||||
tokenizer,
|
|
||||||
};
|
|
||||||
|
|
||||||
let (_, _, content_frames) = mel.shape().r3()?;
|
|
||||||
let mut seek = 0;
|
|
||||||
let mut segments = vec![];
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
while seek < content_frames {
|
|
||||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
|
||||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
|
||||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
|
||||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
|
||||||
let dr = dc.decode_with_fallback(&mel_segment)?;
|
|
||||||
seek += segment_size;
|
|
||||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
|
||||||
println!("no speech detected, skipping {seek} {dr:?}");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let segment = Segment {
|
|
||||||
start: time_offset,
|
|
||||||
duration: segment_duration,
|
|
||||||
dr,
|
|
||||||
};
|
|
||||||
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
|
||||||
segments.push(segment)
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ impl<'a> VarBuilder<'a> {
|
|||||||
pub fn from_safetensors(
|
pub fn from_safetensors(
|
||||||
safetensors: Vec<SafeTensors<'a>>,
|
safetensors: Vec<SafeTensors<'a>>,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
device: Device,
|
device: &Device,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let mut routing = HashMap::new();
|
let mut routing = HashMap::new();
|
||||||
for (index, sf) in safetensors.iter().enumerate() {
|
for (index, sf) in safetensors.iter().enumerate() {
|
||||||
@ -25,7 +25,7 @@ impl<'a> VarBuilder<'a> {
|
|||||||
}
|
}
|
||||||
Self {
|
Self {
|
||||||
safetensors: Some((routing, safetensors)),
|
safetensors: Some((routing, safetensors)),
|
||||||
device,
|
device: device.clone(),
|
||||||
dtype,
|
dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user