diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs index d50b7923..d095e239 100644 --- a/candle-examples/examples/whisper/audio.rs +++ b/candle-examples/examples/whisper/audio.rs @@ -148,7 +148,7 @@ fn log_mel_spectrogram_w( mel } -fn log_mel_spectrogram_( +fn log_mel_spectrogram_( samples: &[T], filters: &[T], fft_size: usize, @@ -200,20 +200,17 @@ fn log_mel_spectrogram_( mel } -pub fn pcm_to_mel( +pub fn pcm_to_mel( samples: &[T], filters: &[T], - n_mel: usize, - n_fft: usize, ) -> anyhow::Result> { - if filters.len() != n_mel * n_fft { - anyhow::bail!( - "unexpected filter length {} (n_mel: {}, n_fft: {})", - filters.len(), - n_mel, - n_fft - ) - } - let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false); + let mel = log_mel_spectrogram_( + samples, + filters, + super::N_FFT, + super::HOP_LENGTH, + super::N_MELS, + false, + ); Ok(mel) } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 6e15fa8a..6ea3e536 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -201,7 +201,6 @@ fn main() -> Result<()> { let mel_filters = mel_filters.deserialize()?; let mel_filters = mel_filters.tensor("mel_80", &device)?; println!("loaded mel filters {:?}", mel_filters.shape()); - let (n_mel, n_fft) = mel_filters.shape().r2()?; let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; let mut input = std::fs::File::open(args.input)?; @@ -215,9 +214,10 @@ fn main() -> Result<()> { .iter() .map(|v| *v as f32 / 32768.) .collect(); - let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?; + println!("pcm data loaded {}", pcm_data.len()); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?; + let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };