Put everything together.

This commit is contained in:
laurent
2023-07-05 12:19:21 +01:00
parent 95f378ebb4
commit 63e5a266bf
2 changed files with 17 additions and 26 deletions

View File

@@ -1,12 +1,6 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp
pub const WHISPER_SAMPLE_RATE: usize = 16000;
pub const WHISPER_N_FFT: usize = 400;
pub const WHISPER_N_MEL: usize = 80;
pub const WHISPER_HOP_LENGTH: usize = 160;
pub const WHISPER_CHUNK_SIZE: usize = 30;
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
impl Float for f32 {}
@@ -175,7 +169,7 @@ fn log_mel_spectrogram_<T: Float>(
let n_len = samples.len() / fft_step;
// pad audio with at least one extra chunk of zeros
let pad = 100 * WHISPER_CHUNK_SIZE / 2;
let pad = 100 * super::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
@@ -206,22 +200,20 @@ fn log_mel_spectrogram_<T: Float>(
mel
}
pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> anyhow::Result<Vec<T>> {
if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT {
pub fn pcm_to_mel<T: Float>(
samples: &[T],
filters: &[T],
n_mel: usize,
n_fft: usize,
) -> anyhow::Result<Vec<T>> {
if filters.len() != n_mel * n_fft {
anyhow::bail!(
"unexpected filter length {} (n_mel: {}, n_fft: {})",
filters.len(),
WHISPER_N_MEL,
WHISPER_N_FFT
n_mel,
n_fft
)
}
let mel = log_mel_spectrogram_(
samples,
filters,
WHISPER_N_FFT,
WHISPER_HOP_LENGTH,
WHISPER_N_MEL,
false,
);
let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false);
Ok(mel)
}