diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs index aca4607f..4b5ed181 100644 --- a/candle-examples/examples/whisper/audio.rs +++ b/candle-examples/examples/whisper/audio.rs @@ -1,6 +1,8 @@ // Audio processing code, adapted from whisper.cpp // https://github.com/ggerganov/whisper.cpp +const WHISPER_CHUNK_SIZE: usize = 30; + trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} // https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 @@ -91,7 +93,7 @@ fn log_mel_spectrogram_w( n_len: usize, n_mel: usize, n_threads: usize, -) { +) -> Vec { let n_fft = if speed_up { 1 + fft_size / 4 } else { @@ -142,4 +144,57 @@ fn log_mel_spectrogram_w( mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); } } + mel +} + +fn log_mel_spectrogram( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * WHISPER_CHUNK_SIZE / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel }