mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Read wav files.
This commit is contained in:
@ -21,6 +21,7 @@ clap = { version = "4.2.4", features = ["derive"] }
|
|||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||||
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
||||||
|
wav = "1.0.0"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["cuda"]
|
default = ["cuda"]
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
// Audio processing code, adapted from whisper.cpp
|
// Audio processing code, adapted from whisper.cpp
|
||||||
// https://github.com/ggerganov/whisper.cpp
|
// https://github.com/ggerganov/whisper.cpp
|
||||||
|
|
||||||
const WHISPER_SAMPLE_RATE: usize = 16000;
|
pub const WHISPER_SAMPLE_RATE: usize = 16000;
|
||||||
const WHISPER_N_FFT: usize = 400;
|
pub const WHISPER_N_FFT: usize = 400;
|
||||||
const WHISPER_N_MEL: usize = 80;
|
pub const WHISPER_N_MEL: usize = 80;
|
||||||
const WHISPER_HOP_LENGTH: usize = 160;
|
pub const WHISPER_HOP_LENGTH: usize = 160;
|
||||||
const WHISPER_CHUNK_SIZE: usize = 30;
|
pub const WHISPER_CHUNK_SIZE: usize = 30;
|
||||||
|
|
||||||
trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
||||||
|
|
||||||
|
impl Float for f32 {}
|
||||||
|
impl Float for f64 {}
|
||||||
|
|
||||||
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
|
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
|
||||||
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
|
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
|
||||||
@ -203,7 +206,7 @@ fn log_mel_spectrogram_<T: Float>(
|
|||||||
mel
|
mel
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> anyhow::Result<Vec<T>> {
|
pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> anyhow::Result<Vec<T>> {
|
||||||
if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT {
|
if filters.len() != WHISPER_N_MEL * WHISPER_N_FFT {
|
||||||
anyhow::bail!(
|
anyhow::bail!(
|
||||||
"unexpected filter length {} (n_mel: {}, n_fft: {})",
|
"unexpected filter length {} (n_mel: {}, n_fft: {})",
|
||||||
|
@ -40,33 +40,6 @@ const EOT_TOKEN: u32 = 50256;
|
|||||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weights: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
input: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_config: String,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
#[arg(
|
|
||||||
long,
|
|
||||||
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
|
|
||||||
)]
|
|
||||||
filters: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecodingResult {
|
struct DecodingResult {
|
||||||
tokens: Vec<u32>,
|
tokens: Vec<u32>,
|
||||||
@ -184,6 +157,35 @@ impl Decode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weights: String,
|
||||||
|
|
||||||
|
/// The input to be processed, in wav formats.
|
||||||
|
#[arg(long)]
|
||||||
|
input: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_config: String,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The mel filters in safetensors format.
|
||||||
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
|
||||||
|
)]
|
||||||
|
filters: String,
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = if args.cpu {
|
let device = if args.cpu {
|
||||||
@ -195,14 +197,28 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||||
|
|
||||||
let filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
||||||
let filters = filters.deserialize()?;
|
let mel_filters = mel_filters.deserialize()?;
|
||||||
let filters = filters.tensor("mel_80", &device)?;
|
let mel_filters = mel_filters.tensor("mel_80", &device)?;
|
||||||
println!("loaded mel filters {:?}", filters.shape());
|
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||||
|
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
|
||||||
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
let mut input = std::fs::File::open(args.input)?;
|
||||||
let input = input.deserialize()?;
|
let (header, data) = wav::read(&mut input)?;
|
||||||
let mel = input.tensor("mel", &device)?;
|
println!("loaded wav data: {header:?}");
|
||||||
|
if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"wav file must have a {} sampling rate",
|
||||||
|
audio::WHISPER_SAMPLE_RATE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||||
|
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||||
|
.iter()
|
||||||
|
.map(|v| *v as f32 / 32768.)
|
||||||
|
.collect();
|
||||||
|
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
|
||||||
|
let mel = Tensor::new(&mel[..], &device)?;
|
||||||
println!("loaded mel: {:?}", mel.dims());
|
println!("loaded mel: {:?}", mel.dims());
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||||
|
Reference in New Issue
Block a user