Merge pull request #76 from LaurentMazare/whisper-mel-bugfix

Bugfix for the mel filters in whisper.
This commit is contained in:
Laurent Mazare
2023-07-05 12:57:13 +01:00
committed by GitHub
2 changed files with 13 additions and 16 deletions

View File

@ -148,7 +148,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel mel
} }
fn log_mel_spectrogram_<T: Float>( fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples: &[T], samples: &[T],
filters: &[T], filters: &[T],
fft_size: usize, fft_size: usize,
@ -200,20 +200,17 @@ fn log_mel_spectrogram_<T: Float>(
mel mel
} }
pub fn pcm_to_mel<T: Float>( pub fn pcm_to_mel<T: Float + std::fmt::Display>(
samples: &[T], samples: &[T],
filters: &[T], filters: &[T],
n_mel: usize,
n_fft: usize,
) -> anyhow::Result<Vec<T>> { ) -> anyhow::Result<Vec<T>> {
if filters.len() != n_mel * n_fft { let mel = log_mel_spectrogram_(
anyhow::bail!( samples,
"unexpected filter length {} (n_mel: {}, n_fft: {})", filters,
filters.len(), super::N_FFT,
n_mel, super::HOP_LENGTH,
n_fft super::N_MELS,
) false,
} );
let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false);
Ok(mel) Ok(mel)
} }

View File

@ -201,7 +201,6 @@ fn main() -> Result<()> {
let mel_filters = mel_filters.deserialize()?; let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80", &device)?; let mel_filters = mel_filters.tensor("mel_80", &device)?;
println!("loaded mel filters {:?}", mel_filters.shape()); println!("loaded mel filters {:?}", mel_filters.shape());
let (n_mel, n_fft) = mel_filters.shape().r2()?;
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mut input = std::fs::File::open(args.input)?; let mut input = std::fs::File::open(args.input)?;
@ -215,9 +214,10 @@ fn main() -> Result<()> {
.iter() .iter()
.map(|v| *v as f32 / 32768.) .map(|v| *v as f32 / 32768.)
.collect(); .collect();
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?; println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
let mel_len = mel.len(); let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?; let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };