diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 4007bd8d..f2560815 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -345,7 +345,7 @@ enum Task { Translate, } -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] enum WhichModel { Tiny, #[value(name = "tiny.en")] @@ -361,6 +361,7 @@ enum WhichModel { MediumEn, Large, LargeV2, + LargeV3, #[value(name = "distil-medium.en")] DistilMediumEn, #[value(name = "distil-large-v2")] @@ -376,6 +377,7 @@ impl WhichModel { | Self::Medium | Self::Large | Self::LargeV2 + | Self::LargeV3 | Self::DistilLargeV2 => true, Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => { false @@ -395,6 +397,7 @@ impl WhichModel { Self::MediumEn => ("openai/whisper-medium.en", "main"), Self::Large => ("openai/whisper-large", "refs/pr/36"), Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), + Self::LargeV3 => ("openai/whisper-large-v3", "main"), Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), } @@ -508,14 +511,18 @@ fn main() -> Result<()> { repo.get(&format!("model-{ext}-q80.gguf"))?, ) } else { - ( - repo.get("config.json")?, - repo.get("tokenizer.json")?, - repo.get("model.safetensors")?, - ) + let config = repo.get("config.json")?; + let tokenizer = if args.model == WhichModel::LargeV3 { + panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment") + } else { + repo.get("tokenizer.json")? 
+ }; + let model = repo.get("model.safetensors")?; + (config, tokenizer, model) }; (config, tokenizer, model, sample) }; + let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let mel_bytes = include_bytes!("melfilters.bytes"); @@ -534,12 +541,15 @@ fn main() -> Result<()> { .map(|v| *v as f32 / 32768.) .collect(); println!("pcm data loaded {}", pcm_data.len()); - let mel = audio::pcm_to_mel(&pcm_data, &mel_filters); + let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters); let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; + let mel = Tensor::from_vec( + mel, + (1, config.num_mel_bins, mel_len / config.num_mel_bins), + &device, + )?; println!("loaded mel: {:?}", mel.dims()); - let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let mut model = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?; diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs index 4e01de32..92ebd954 100644 --- a/candle-transformers/src/models/whisper/audio.rs +++ b/candle-transformers/src/models/whisper/audio.rs @@ -198,13 +198,17 @@ fn log_mel_spectrogram_<T: Float>( mel } -pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> Vec<T> { +pub fn pcm_to_mel<T: Float>( + cfg: &super::Config, + samples: &[T], + filters: &[T], +) -> Vec<T> { log_mel_spectrogram_( samples, filters, super::N_FFT, super::HOP_LENGTH, - super::N_MELS, + cfg.num_mel_bins, false, ) } diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs index 35d35e77..bf24045a 100644 --- a/candle-transformers/src/models/whisper/mod.rs +++ b/candle-transformers/src/models/whisper/mod.rs @@ -18,6 +18,7 @@ pub struct Config { // pub n_text_state: usize, pub decoder_attention_heads: usize, //
n_text_head pub decoder_layers: usize, // n_text_layer + #[serde(default)] pub suppress_tokens: Vec<u32>, } @@ -26,7 +27,6 @@ pub const DTYPE: candle::DType = candle::DType::F32; // Audio parameters. pub const SAMPLE_RATE: usize = 16000; pub const N_FFT: usize = 400; -pub const N_MELS: usize = 80; pub const HOP_LENGTH: usize = 160; pub const CHUNK_LENGTH: usize = 30; pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk diff --git a/candle-wasm-examples/whisper/src/audio.rs b/candle-wasm-examples/whisper/src/audio.rs index 10974d15..a3ca6c73 100644 --- a/candle-wasm-examples/whisper/src/audio.rs +++ b/candle-wasm-examples/whisper/src/audio.rs @@ -200,6 +200,7 @@ fn log_mel_spectrogram_<T: Float>( } pub fn pcm_to_mel<T: Float>( + cfg: &worker::m::Config, samples: &[T], filters: &[T], ) -> anyhow::Result<Vec<T>> { @@ -208,7 +209,7 @@ pub fn pcm_to_mel<T: Float>( filters, worker::m::N_FFT, worker::m::HOP_LENGTH, - worker::m::N_MELS, + cfg.num_mel_bins, false, ); Ok(mel) diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index a8646e3d..09d4f580 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -349,9 +349,10 @@ impl Decoder { .map(|v| *v as f32 / 32768.) .collect(); console_log!("pcm data loaded {}", pcm_data.len()); - let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?; + let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?; let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; + let n_mels = self.model.config().num_mel_bins; + let mel = Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &device)?; console_log!("loaded mel: {:?}", mel.dims()); let segments = self.run(&mel)?; Ok(segments)