mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Supports more audio formats (#1628)
* Supports more audio formats
* Simplify the handling of the different buffer types.
* Check the sample rate.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
|
|||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
hf-hub = { workspace = true, features=["tokio"]}
|
hf-hub = { workspace = true, features = ["tokio"] }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
@ -30,6 +30,7 @@ rayon = { workspace = true }
|
|||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
symphonia = { version = "0.5.3", features = ["all"] }
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
tokenizers = { workspace = true, features = ["onig"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
@ -43,7 +44,6 @@ rusttype = { workspace = true }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-chrome = { workspace = true }
|
tracing-chrome = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
wav = { workspace = true }
|
|
||||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||||
tokio = "1.29.1"
|
tokio = "1.29.1"
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ use rand::{distributions::Distribution, SeedableRng};
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod multilingual;
|
mod multilingual;
|
||||||
|
mod pcm_decode;
|
||||||
|
|
||||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||||
|
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
@ -535,17 +537,10 @@ fn main() -> Result<()> {
|
|||||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||||
|
|
||||||
let mut input = std::fs::File::open(input)?;
|
let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
|
||||||
let (header, data) = wav::read(&mut input)?;
|
if sample_rate != m::SAMPLE_RATE as u32 {
|
||||||
println!("loaded wav data: {header:?}");
|
anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
|
||||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
|
|
||||||
}
|
}
|
||||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
|
||||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
|
||||||
.iter()
|
|
||||||
.map(|v| *v as f32 / 32768.)
|
|
||||||
.collect();
|
|
||||||
println!("pcm data loaded {}", pcm_data.len());
|
println!("pcm data loaded {}", pcm_data.len());
|
||||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||||
let mel_len = mel.len();
|
let mel_len = mel.len();
|
||||||
|
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||||
|
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
|
||||||
|
use symphonia::core::conv::FromSample;
|
||||||
|
|
||||||
|
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||||
|
where
|
||||||
|
T: symphonia::core::sample::Sample,
|
||||||
|
f32: symphonia::core::conv::FromSample<T>,
|
||||||
|
{
|
||||||
|
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||||
|
// Open the media source.
|
||||||
|
let src = std::fs::File::open(path)?;
|
||||||
|
|
||||||
|
// Create the media source stream.
|
||||||
|
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||||
|
|
||||||
|
// Create a probe hint using the file's extension. [Optional]
|
||||||
|
let hint = symphonia::core::probe::Hint::new();
|
||||||
|
|
||||||
|
// Use the default options for metadata and format readers.
|
||||||
|
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||||
|
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||||
|
|
||||||
|
// Probe the media source.
|
||||||
|
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||||
|
// Get the instantiated format reader.
|
||||||
|
let mut format = probed.format;
|
||||||
|
|
||||||
|
// Find the first audio track with a known (decodeable) codec.
|
||||||
|
let track = format
|
||||||
|
.tracks()
|
||||||
|
.iter()
|
||||||
|
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
|
||||||
|
.expect("no supported audio tracks");
|
||||||
|
|
||||||
|
// Use the default options for the decoder.
|
||||||
|
let dec_opts: DecoderOptions = Default::default();
|
||||||
|
|
||||||
|
// Create a decoder for the track.
|
||||||
|
let mut decoder = symphonia::default::get_codecs()
|
||||||
|
.make(&track.codec_params, &dec_opts)
|
||||||
|
.expect("unsupported codec");
|
||||||
|
let track_id = track.id;
|
||||||
|
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||||
|
let mut pcm_data = Vec::new();
|
||||||
|
// The decode loop.
|
||||||
|
while let Ok(packet) = format.next_packet() {
|
||||||
|
// Consume any new metadata that has been read since the last packet.
|
||||||
|
while !format.metadata().is_latest() {
|
||||||
|
format.metadata().pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the packet does not belong to the selected track, skip over it.
|
||||||
|
if packet.track_id() != track_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match decoder.decode(&packet)? {
|
||||||
|
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||||
|
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((pcm_data, sample_rate))
|
||||||
|
}
|
@ -23,7 +23,6 @@ serde = { workspace = true }
|
|||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
serde_plain = { workspace = true }
|
serde_plain = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
wav = { workspace = true }
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
Reference in New Issue
Block a user