diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 00340d08..af5f791a 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
 csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
-hf-hub = { workspace = true, features=["tokio"]}
+hf-hub = { workspace = true, features = ["tokio"] }
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
@@ -30,6 +30,7 @@ rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
+symphonia = { version = "0.5.3", features = ["all"] }
 tokenizers = { workspace = true, features = ["onig"] }
 
 [dev-dependencies]
@@ -43,7 +44,6 @@ rusttype = { workspace = true }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
-wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 tokio = "1.29.1"
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 6ea34613..da8c73ae 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -18,6 +18,8 @@ use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
 
 mod multilingual;
+mod pcm_decode;
+
 use candle_transformers::models::whisper::{self as m, audio, Config};
 
 pub enum Model {
@@ -535,17 +537,10 @@ fn main() -> Result<()> {
     let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
     <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
 
-    let mut input = std::fs::File::open(input)?;
-    let (header, data) = wav::read(&mut input)?;
-    println!("loaded wav data: {header:?}");
-    if header.sampling_rate != m::SAMPLE_RATE as u32 {
-        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
+    let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
+    if sample_rate != m::SAMPLE_RATE as u32 {
+        anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
     }
-    let data = data.as_sixteen().expect("expected 16 bit wav file");
-    let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
-        .iter()
-        .map(|v| *v as f32 / 32768.)
-        .collect();
     println!("pcm data loaded {}", pcm_data.len());
     let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
     let mel_len = mel.len();
diff --git a/candle-examples/examples/whisper/pcm_decode.rs b/candle-examples/examples/whisper/pcm_decode.rs
new file mode 100644
index 00000000..e75d3ffd
--- /dev/null
+++ b/candle-examples/examples/whisper/pcm_decode.rs
@@ -0,0 +1,86 @@
+//! Decoding of audio files into `f32` PCM samples using symphonia.
+
+use symphonia::core::audio::{AudioBufferRef, Signal};
+use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
+use symphonia::core::conv::FromSample;
+
+/// Appends the samples of channel 0 of `data` to `samples`, converting each one to `f32`.
+fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
+where
+    T: symphonia::core::sample::Sample,
+    f32: symphonia::core::conv::FromSample<T>,
+{
+    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
+}
+
+/// Decodes the audio file at `path` into `f32` PCM samples.
+///
+/// Only channel 0 of the first track with a known codec is decoded. Returns the samples
+/// together with the track's sample rate, or `0` when the track does not report one.
+///
+/// # Errors
+///
+/// Fails if the file cannot be opened or probed, if no track has a supported codec, if the
+/// decoder cannot be created, or if a packet of the selected track cannot be decoded.
+pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
+    // Open the media source.
+    let src = std::fs::File::open(path)?;
+
+    // Create the media source stream.
+    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
+
+    // Use an empty probe hint: no file extension or mime type is passed to the probe.
+    let hint = symphonia::core::probe::Hint::new();
+
+    // Use the default options for metadata and format readers.
+    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
+    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
+
+    // Probe the media source.
+    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
+    // Get the instantiated format reader.
+    let mut format = probed.format;
+
+    // Find the first audio track with a known (decodeable) codec.
+    let track = format
+        .tracks()
+        .iter()
+        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
+        .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
+
+    // Use the default options for the decoder.
+    let dec_opts: DecoderOptions = Default::default();
+
+    // Create a decoder for the track.
+    let mut decoder = symphonia::default::get_codecs()
+        .make(&track.codec_params, &dec_opts)
+        .map_err(|err| anyhow::anyhow!("unsupported codec: {err}"))?;
+    let track_id = track.id;
+    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
+    let mut pcm_data = Vec::new();
+    // The decode loop: runs until `next_packet` returns an error.
+    while let Ok(packet) = format.next_packet() {
+        // Consume any new metadata that has been read since the last packet.
+        while !format.metadata().is_latest() {
+            format.metadata().pop();
+        }
+
+        // If the packet does not belong to the selected track, skip over it.
+        if packet.track_id() != track_id {
+            continue;
+        }
+        match decoder.decode(&packet)? {
+            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
+            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
+            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
+        }
+    }
+    Ok((pcm_data, sample_rate))
+}
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 1a72c36a..0e55ab8c 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -23,7 +23,6 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 serde_plain = { workspace = true }
 tracing = { workspace = true }
-wav = { workspace = true }
 
 [features]
 default = []