mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
80 lines
2.2 KiB
Rust
80 lines
2.2 KiB
Rust
// Audio processing code, adapted from whisper.cpp
|
|
// https://github.com/ggerganov/whisper.cpp
|
|
|
|
// Numeric bound shared by the transforms in this file: a floating-point type
// that also exposes the `FloatConst` constants (the code uses `PI`) and the
// compound-assignment operators (`+=`, `-=`) from `NumAssign`.
// NOTE(review): no impls are visible in this part of the file; presumably
// `f32`/`f64` implement it elsewhere - confirm.
trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
|
|
|
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
|
|
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
|
|
let n = inp.len();
|
|
let zero = T::zero();
|
|
if n == 1 {
|
|
return vec![inp[0], zero];
|
|
}
|
|
if n % 2 == 1 {
|
|
return dft(inp);
|
|
}
|
|
let mut out = vec![zero; n * 2];
|
|
|
|
let mut even = vec![];
|
|
even.reserve(n / 2);
|
|
let mut odd = vec![];
|
|
odd.reserve(n / 2);
|
|
|
|
for (i, &inp) in inp.iter().enumerate() {
|
|
if i % 2 == 0 {
|
|
even.push(inp)
|
|
} else {
|
|
odd.push(inp);
|
|
}
|
|
}
|
|
|
|
let even_fft = fft(&even);
|
|
let odd_fft = fft(&odd);
|
|
|
|
let two_pi = T::PI() + T::PI();
|
|
let n_t = T::from(n).unwrap();
|
|
for k in 0..n / 2 {
|
|
let k_t = T::from(k).unwrap();
|
|
let theta = two_pi * k_t / n_t;
|
|
let re = theta.cos();
|
|
let im = -theta.sin();
|
|
|
|
let re_odd = odd_fft[2 * k];
|
|
let im_odd = odd_fft[2 * k + 1];
|
|
|
|
out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
|
|
out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
|
|
|
|
out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
|
|
out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
|
|
}
|
|
out
|
|
}
|
|
|
|
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
|
|
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
|
|
let zero = T::zero();
|
|
let n = inp.len();
|
|
let two_pi = T::PI() + T::PI();
|
|
|
|
let mut out = Vec::new();
|
|
out.reserve(2 * n);
|
|
let n_t = T::from(n).unwrap();
|
|
for k in 0..n {
|
|
let k_t = T::from(k).unwrap();
|
|
let mut re = zero;
|
|
let mut im = zero;
|
|
|
|
for (j, &inp) in inp.iter().enumerate() {
|
|
let j_t = T::from(j).unwrap();
|
|
let angle = two_pi * k_t * j_t / n_t;
|
|
re += inp * angle.cos();
|
|
im -= inp * angle.sin();
|
|
}
|
|
|
|
out.push(re);
|
|
out.push(im);
|
|
}
|
|
out
|
|
}
|