diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs index dbf5d2c0..d50b7923 100644 --- a/candle-examples/examples/whisper/audio.rs +++ b/candle-examples/examples/whisper/audio.rs @@ -1,12 +1,6 @@ // Audio processing code, adapted from whisper.cpp // https://github.com/ggerganov/whisper.cpp -pub const WHISPER_SAMPLE_RATE: usize = 16000; -pub const WHISPER_N_FFT: usize = 400; -pub const WHISPER_N_MEL: usize = 80; -pub const WHISPER_HOP_LENGTH: usize = 160; -pub const WHISPER_CHUNK_SIZE: usize = 30; - pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} impl Float for f32 {} @@ -175,7 +169,7 @@ fn log_mel_spectrogram_( let n_len = samples.len() / fft_step; // pad audio with at least one extra chunk of zeros - let pad = 100 * WHISPER_CHUNK_SIZE / 2; + let pad = 100 * super::CHUNK_LENGTH / 2; let n_len = if n_len % pad != 0 { (n_len / pad + 1) * pad } else { @@ -206,22 +200,20 @@ fn log_mel_spectrogram_( mel } -pub fn pcm_to_mel(samples: &[T], filters: &[T]) -> anyhow::Result> { - if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT { +pub fn pcm_to_mel( + samples: &[T], + filters: &[T], + n_mel: usize, + n_fft: usize, +) -> anyhow::Result> { + if filters.len() != n_mel * n_fft { anyhow::bail!( "unexpected filter length {} (n_mel: {}, n_fft: {})", filters.len(), - WHISPER_N_MEL, - WHISPER_N_FFT + n_mel, + n_fft ) } - let mel = log_mel_spectrogram_( - samples, - filters, - WHISPER_N_FFT, - WHISPER_HOP_LENGTH, - WHISPER_N_MEL, - false, - ); + let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false); Ok(mel) } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 231feac8..6e15fa8a 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -201,24 +201,23 @@ fn main() -> Result<()> { let mel_filters = mel_filters.deserialize()?; let mel_filters = mel_filters.tensor("mel_80", &device)?; println!("loaded mel filters {:?}", mel_filters.shape()); + let (n_mel, n_fft) = mel_filters.shape().r2()?; let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; let mut input = std::fs::File::open(args.input)?; let (header, data) = wav::read(&mut input)?; println!("loaded wav data: {header:?}"); - if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 { - anyhow::bail!( - "wav file must have a {} sampling rate", - audio::WHISPER_SAMPLE_RATE - ) + if header.sampling_rate != SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) } let data = data.as_sixteen().expect("expected 16 bit wav file"); let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] .iter() .map(|v| *v as f32 / 32768.) .collect(); - let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; - let mel = Tensor::new(&mel[..], &device)?; + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };