mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Put everything together.
This commit is contained in:
@ -1,12 +1,6 @@
|
||||
// Audio processing code, adapted from whisper.cpp
|
||||
// https://github.com/ggerganov/whisper.cpp
|
||||
|
||||
pub const WHISPER_SAMPLE_RATE: usize = 16000;
|
||||
pub const WHISPER_N_FFT: usize = 400;
|
||||
pub const WHISPER_N_MEL: usize = 80;
|
||||
pub const WHISPER_HOP_LENGTH: usize = 160;
|
||||
pub const WHISPER_CHUNK_SIZE: usize = 30;
|
||||
|
||||
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
||||
|
||||
impl Float for f32 {}
|
||||
@ -175,7 +169,7 @@ fn log_mel_spectrogram_<T: Float>(
|
||||
let n_len = samples.len() / fft_step;
|
||||
|
||||
// pad audio with at least one extra chunk of zeros
|
||||
let pad = 100 * WHISPER_CHUNK_SIZE / 2;
|
||||
let pad = 100 * super::CHUNK_LENGTH / 2;
|
||||
let n_len = if n_len % pad != 0 {
|
||||
(n_len / pad + 1) * pad
|
||||
} else {
|
||||
@ -206,22 +200,20 @@ fn log_mel_spectrogram_<T: Float>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> anyhow::Result<Vec<T>> {
|
||||
if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT {
|
||||
pub fn pcm_to_mel<T: Float>(
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
n_mel: usize,
|
||||
n_fft: usize,
|
||||
) -> anyhow::Result<Vec<T>> {
|
||||
if filters.len() != n_mel * n_fft {
|
||||
anyhow::bail!(
|
||||
"unexpected filter length {} (n_mel: {}, n_fft: {})",
|
||||
filters.len(),
|
||||
WHISPER_N_MEL,
|
||||
WHISPER_N_FFT
|
||||
n_mel,
|
||||
n_fft
|
||||
)
|
||||
}
|
||||
let mel = log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
WHISPER_N_FFT,
|
||||
WHISPER_HOP_LENGTH,
|
||||
WHISPER_N_MEL,
|
||||
false,
|
||||
);
|
||||
let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false);
|
||||
Ok(mel)
|
||||
}
|
||||
|
@ -201,24 +201,23 @@ fn main() -> Result<()> {
|
||||
let mel_filters = mel_filters.deserialize()?;
|
||||
let mel_filters = mel_filters.tensor("mel_80", &device)?;
|
||||
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let (n_mel, n_fft) = mel_filters.shape().r2()?;
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
let mut input = std::fs::File::open(args.input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
println!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 {
|
||||
anyhow::bail!(
|
||||
"wav file must have a {} sampling rate",
|
||||
audio::WHISPER_SAMPLE_RATE
|
||||
)
|
||||
if header.sampling_rate != SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
|
||||
let mel = Tensor::new(&mel[..], &device)?;
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?;
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
|
Reference in New Issue
Block a user