From 95f378ebb41bdf5abbdb71167e3df9eab14789df Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 5 Jul 2023 11:53:58 +0100 Subject: [PATCH] Read wav files. --- candle-examples/Cargo.toml | 1 + candle-examples/examples/whisper/audio.rs | 17 +++-- candle-examples/examples/whisper/main.rs | 84 ++++++++++++++--------- 3 files changed, 61 insertions(+), 41 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index ca910441..a3e64a17 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -21,6 +21,7 @@ clap = { version = "4.2.4", features = ["derive"] } rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } +wav = "1.0.0" [features] default = ["cuda"] diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs index 64858141..dbf5d2c0 100644 --- a/candle-examples/examples/whisper/audio.rs +++ b/candle-examples/examples/whisper/audio.rs @@ -1,13 +1,16 @@ // Audio processing code, adapted from whisper.cpp // https://github.com/ggerganov/whisper.cpp -const WHISPER_SAMPLE_RATE: usize = 16000; -const WHISPER_N_FFT: usize = 400; -const WHISPER_N_MEL: usize = 80; -const WHISPER_HOP_LENGTH: usize = 160; -const WHISPER_CHUNK_SIZE: usize = 30; +pub const WHISPER_SAMPLE_RATE: usize = 16000; +pub const WHISPER_N_FFT: usize = 400; +pub const WHISPER_N_MEL: usize = 80; +pub const WHISPER_HOP_LENGTH: usize = 160; +pub const WHISPER_CHUNK_SIZE: usize = 30; -trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} // https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 fn fft(inp: &[T]) -> Vec { @@ -203,7 +206,7 @@ fn log_mel_spectrogram_( mel } -fn pcm_to_mel(samples: &[T], filters: &[T]) -> anyhow::Result> { +pub fn pcm_to_mel(samples: &[T], filters: &[T]) -> anyhow::Result> { if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT { anyhow::bail!( "unexpected filter length {} (n_mel: {}, n_fft: {})", diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 66bde0e8..231feac8 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -40,33 +40,6 @@ const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; const NO_TIMESTAMP_TOKEN: u32 = 50362; -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - - #[arg(long)] - weights: String, - - #[arg(long)] - input: String, - - #[arg(long)] - tokenizer_config: String, - - /// The seed to use when generating random samples. - #[arg(long, default_value_t = 299792458)] - seed: u64, - - #[arg( - long, - default_value = "candle-examples/examples/whisper/mel_filters.safetensors" - )] - filters: String, -} - #[derive(Debug, Clone)] struct DecodingResult { tokens: Vec, @@ -184,6 +157,35 @@ impl Decode { } } +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + weights: String, + + /// The input to be processed, in wav formats. + #[arg(long)] + input: String, + + #[arg(long)] + tokenizer_config: String, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The mel filters in safetensors format. + #[arg( + long, + default_value = "candle-examples/examples/whisper/mel_filters.safetensors" + )] + filters: String, +} + fn main() -> Result<()> { let args = Args::parse(); let device = if args.cpu { @@ -195,14 +197,28 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; - let filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; - let filters = filters.deserialize()?; - let filters = filters.tensor("mel_80", &device)?; - println!("loaded mel filters {:?}", filters.shape()); + let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; + let mel_filters = mel_filters.deserialize()?; + let mel_filters = mel_filters.tensor("mel_80", &device)?; + println!("loaded mel filters {:?}", mel_filters.shape()); + let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; - let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? }; - let input = input.deserialize()?; - let mel = input.tensor("mel", &device)?; + let mut input = std::fs::File::open(args.input)?; + let (header, data) = wav::read(&mut input)?; + println!("loaded wav data: {header:?}"); + if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 { + anyhow::bail!( + "wav file must have a {} sampling rate", + audio::WHISPER_SAMPLE_RATE + ) + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; + let mel = Tensor::new(&mel[..], &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };