Put everything together.

This commit is contained in:
laurent
2023-07-05 12:19:21 +01:00
parent 95f378ebb4
commit 63e5a266bf
2 changed files with 17 additions and 26 deletions

View File

@@ -201,24 +201,23 @@ fn main() -> Result<()> {
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80", &device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let (n_mel, n_fft) = mel_filters.shape().r2()?;
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mut input = std::fs::File::open(args.input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 {
anyhow::bail!(
"wav file must have a {} sampling rate",
audio::WHISPER_SAMPLE_RATE
)
if header.sampling_rate != SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
let mel = Tensor::new(&mel[..], &device)?;
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?;
let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };