mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Supports more audio formats (#1628)
* Supports more audio formats * Simplify the handling of the different buffer types. * Check the sample rate. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
hf-hub = { workspace = true, features = ["tokio"] }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
@ -30,6 +30,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"] }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
|
||||
[dev-dependencies]
|
||||
@ -43,7 +44,6 @@ rusttype = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
|
||||
|
@ -18,6 +18,8 @@ use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
mod pcm_decode;
|
||||
|
||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
pub enum Model {
|
||||
@ -535,17 +537,10 @@ fn main() -> Result<()> {
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
println!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||
let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
|
||||
if sample_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
println!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
|
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
|
||||
use symphonia::core::conv::FromSample;
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||
// Open the media source.
|
||||
let src = std::fs::File::open(path)?;
|
||||
|
||||
// Create the media source stream.
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
|
||||
// Create a probe hint using the file's extension. [Optional]
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
|
||||
// Use the default options for metadata and format readers.
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
|
||||
// Probe the media source.
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
// Get the instantiated format reader.
|
||||
let mut format = probed.format;
|
||||
|
||||
// Find the first audio track with a known (decodeable) codec.
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
|
||||
// Use the default options for the decoder.
|
||||
let dec_opts: DecoderOptions = Default::default();
|
||||
|
||||
// Create a decoder for the track.
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &dec_opts)
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
// The decode loop.
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
// Consume any new metadata that has been read since the last packet.
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
|
||||
// If the packet does not belong to the selected track, skip over it.
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
@ -23,7 +23,6 @@ serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_plain = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
Reference in New Issue
Block a user