mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
feat: support multithread spectrogram and small perf tweaks (#1674)
* feat: support multithread spectrogram and small perf tweaks * feat: clippy improvement for loop variable * fix: add back speed up scale down logic * fix: readd mirroring logic * feat: prefer scoped thread and simplify/improve logic/traits
This commit is contained in:
@ -1,7 +1,14 @@
|
||||
// Audio processing code, adapted from whisper.cpp
|
||||
// https://github.com/ggerganov/whisper.cpp
|
||||
|
||||
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
||||
use candle::utils::get_num_threads;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
pub trait Float:
|
||||
num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
|
||||
{
|
||||
}
|
||||
|
||||
impl Float for f32 {}
|
||||
impl Float for f64 {}
|
||||
@ -102,22 +109,26 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
let half = T::from(0.5).unwrap();
|
||||
let mut fft_in = vec![zero; fft_size];
|
||||
let mut mel = vec![zero; n_len * n_mel];
|
||||
let n_samples = samples.len();
|
||||
let end = std::cmp::min(n_samples / fft_step + 1, n_len);
|
||||
|
||||
for i in (ith..n_len).step_by(n_threads) {
|
||||
for i in (ith..end).step_by(n_threads) {
|
||||
let offset = i * fft_step;
|
||||
|
||||
// apply Hanning window
|
||||
for j in 0..fft_size {
|
||||
fft_in[j] = if offset + j < samples.len() {
|
||||
hann[j] * samples[offset + j]
|
||||
} else {
|
||||
zero
|
||||
}
|
||||
for j in 0..std::cmp::min(fft_size, n_samples - offset) {
|
||||
fft_in[j] = hann[j] * samples[offset + j];
|
||||
}
|
||||
|
||||
// FFT -> mag^2
|
||||
// fill the rest with zeros
|
||||
if n_samples - offset < fft_size {
|
||||
fft_in[n_samples - offset..].fill(zero);
|
||||
}
|
||||
|
||||
// FFT
|
||||
let mut fft_out: Vec<T> = fft(&fft_in);
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
for j in 0..fft_size {
|
||||
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
|
||||
}
|
||||
@ -136,8 +147,19 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
// mel spectrogram
|
||||
for j in 0..n_mel {
|
||||
let mut sum = zero;
|
||||
for k in 0..n_fft {
|
||||
let mut k = 0;
|
||||
// Unroll loop
|
||||
while k < n_fft.saturating_sub(3) {
|
||||
sum += fft_out[k] * filters[j * n_fft + k]
|
||||
+ fft_out[k + 1] * filters[j * n_fft + k + 1]
|
||||
+ fft_out[k + 2] * filters[j * n_fft + k + 2]
|
||||
+ fft_out[k + 3] * filters[j * n_fft + k + 3];
|
||||
k += 4;
|
||||
}
|
||||
// Handle remainder
|
||||
while k < n_fft {
|
||||
sum += fft_out[k] * filters[j * n_fft + k];
|
||||
k += 1;
|
||||
}
|
||||
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
|
||||
}
|
||||
@ -145,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
mel
|
||||
}
|
||||
|
||||
fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
fn log_mel_spectrogram_<T: Float>(
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
fft_size: usize,
|
||||
@ -180,10 +202,55 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
samples_padded
|
||||
};
|
||||
|
||||
// Use a single thread for now.
|
||||
let mut mel = log_mel_spectrogram_w(
|
||||
0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
|
||||
);
|
||||
// ensure that the number of threads is even and less than 12
|
||||
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
|
||||
|
||||
let hann = Arc::new(hann);
|
||||
let samples = Arc::new(samples);
|
||||
let filters = Arc::new(filters);
|
||||
|
||||
// use scope to allow for non static references to be passed to the threads
|
||||
// and directly collect the results into a single vector
|
||||
let all_outputs = thread::scope(|s| {
|
||||
(0..n_threads)
|
||||
// create threads and return their handles
|
||||
.map(|thread_id| {
|
||||
let hann = Arc::clone(&hann);
|
||||
let samples = Arc::clone(&samples);
|
||||
let filters = Arc::clone(&filters);
|
||||
// spawn new thread and start work
|
||||
s.spawn(move || {
|
||||
log_mel_spectrogram_w(
|
||||
thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
|
||||
n_mel, n_threads,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
// wait for each thread to finish and collect their results
|
||||
.map(|handle| handle.join().expect("Thread failed"))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let l = all_outputs[0].len();
|
||||
let mut mel = vec![zero; l];
|
||||
|
||||
// iterate over mel spectrogram segments, dividing work by threads.
|
||||
for segment_start in (0..l).step_by(n_threads) {
|
||||
// go through each thread's output.
|
||||
for thread_output in all_outputs.iter() {
|
||||
// add each thread's piece to our mel spectrogram.
|
||||
for offset in 0..n_threads {
|
||||
let mel_index = segment_start + offset; // find location in mel.
|
||||
if mel_index < mel.len() {
|
||||
// Make sure we don't go out of bounds.
|
||||
mel[mel_index] += thread_output[mel_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mmax = mel
|
||||
.iter()
|
||||
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
|
||||
@ -197,11 +264,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &super::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> Vec<T> {
|
||||
pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
|
||||
log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
@ -211,3 +274,62 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fft() {
|
||||
let input = vec![0.0, 1.0, 0.0, 0.0];
|
||||
let output = fft(&input);
|
||||
assert_eq!(
|
||||
output,
|
||||
vec![
|
||||
1.0,
|
||||
0.0,
|
||||
6.123233995736766e-17,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-6.123233995736766e-17,
|
||||
1.0
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dft() {
|
||||
let input = vec![0.0, 1.0, 0.0, 0.0];
|
||||
let output = dft(&input);
|
||||
assert_eq!(
|
||||
output,
|
||||
vec![
|
||||
1.0,
|
||||
0.0,
|
||||
6.123233995736766e-17,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.2246467991473532e-16,
|
||||
-1.8369701987210297e-16,
|
||||
1.0
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log_mel_spectrogram() {
|
||||
let samples = vec![0.0; 1000];
|
||||
let filters = vec![0.0; 1000];
|
||||
let output = log_mel_spectrogram_(&samples, &filters, 100, 10, 10, false);
|
||||
assert_eq!(output.len(), 30_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tiny_log_mel_spectrogram() {
|
||||
let samples = vec![0.0; 100];
|
||||
let filters = vec![0.0; 100];
|
||||
let output = log_mel_spectrogram_(&samples, &filters, 20, 2, 2, false);
|
||||
assert_eq!(output.len(), 6_000);
|
||||
}
|
||||
}
|
||||
|
@ -195,14 +195,14 @@ impl ResidualAttentionBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, device)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
@ -246,7 +246,7 @@ impl AudioEncoder {
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
|
||||
|
@ -191,14 +191,14 @@ impl ResidualAttentionBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, device)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
@ -242,7 +242,7 @@ impl AudioEncoder {
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
|
||||
|
Reference in New Issue
Block a user