From 6110ad8d4ff8272bdd10687eae4edee59a07b517 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 1 Oct 2024 00:24:17 +0200 Subject: [PATCH] Refactor the whisper microphone example. (#2523) * Refactor the whisper microphone example. * Tweak the whisper microphone example more. --- candle-examples/Cargo.toml | 2 +- .../examples/whisper-microphone/main.rs | 152 +++++++++--------- 2 files changed, 73 insertions(+), 81 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 543c9666..2c96f87d 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/ nccl = ["cuda", "cudarc/nccl", "dep:half"] onnx = ["candle-onnx"] metal = ["candle/metal", "candle-nn/metal"] -microphone = ["cpal"] +microphone = ["cpal", "rubato"] encodec = ["cpal", "symphonia", "rubato"] mimi = ["cpal", "symphonia", "rubato"] depth_anything_v2 = ["palette", "enterpolation"] diff --git a/candle-examples/examples/whisper-microphone/main.rs b/candle-examples/examples/whisper-microphone/main.rs index 9f7d5b82..44a64b05 100644 --- a/candle-examples/examples/whisper-microphone/main.rs +++ b/candle-examples/examples/whisper-microphone/main.rs @@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder}; use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; -use std::iter; use tokenizers::Tokenizer; mod multilingual; @@ -18,7 +17,6 @@ mod multilingual; use candle_transformers::models::whisper::{self as m, audio, Config}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use std::sync::{Arc, Mutex}; pub enum Model { Normal(m::model::Whisper), @@ -479,6 +477,10 @@ struct Args { /// Print the full DecodingResult structure rather than just the text. #[arg(long)] verbose: bool, + + /// The input device to use. 
+ #[arg(long)] + device: Option<String>, } pub fn main() -> Result<()> { @@ -543,13 +545,12 @@ pub fn main() -> Result<()> { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? }; Model::Normal(m::model::Whisper::load(&vb, config.clone())?) }; - let language_token = None; - let mut dc = Decoder::new( + let mut decoder = Decoder::new( model, tokenizer.clone(), args.seed, &device, - language_token, + /* language_token */ None, args.task, args.timestamps, args.verbose, @@ -565,47 +566,69 @@ pub fn main() -> Result<()> { // Set up the input device and stream with the default input config. let host = cpal::default_host(); - let _device = "default"; - let _device = if _device == "default" { - host.default_input_device() - } else { - host.input_devices()? - .find(|x| x.name().map(|y| y == _device).unwrap_or(false)) + let audio_device = match args.device.as_ref() { + None => host.default_input_device(), + Some(device) => host + .input_devices()? + .find(|x| x.name().map_or(false, |y| &y == device)), } - .expect("failed to find input device"); + .expect("failed to find the audio input device"); - let _config = _device + let audio_config = audio_device .default_input_config() .expect("Failed to get default input config"); + println!("audio config {audio_config:?}"); - let channel_count = _config.channels() as usize; - - let audio_ring_buffer = Arc::new(Mutex::new(Vec::new())); - let audio_ring_buffer_2 = audio_ring_buffer.clone(); - - std::thread::spawn(move || loop { - let data = record_audio(&_device, &_config, 300).unwrap(); - audio_ring_buffer.lock().unwrap().extend_from_slice(&data); - let max_len = data.len() * 16; - let data_len = data.len(); - let len = audio_ring_buffer.lock().unwrap().len(); - if len > max_len { - let mut data = audio_ring_buffer.lock().unwrap(); - let new_data = data[data_len..].to_vec(); - *data = new_data; - } - }); + let channel_count = audio_config.channels() as usize; + let in_sample_rate = 
audio_config.sample_rate().0 as usize; + let resample_ratio = 16000. / in_sample_rate as f64; + let mut resampler = rubato::FastFixedIn::new( + resample_ratio, + 10., + rubato::PolynomialDegree::Septic, + 1024, + 1, + )?; + let (tx, rx) = std::sync::mpsc::channel(); + let stream = audio_device.build_input_stream( + &audio_config.config(), + move |pcm: &[f32], _: &cpal::InputCallbackInfo| { + let pcm = pcm + .iter() + .step_by(channel_count) + .copied() + .collect::<Vec<f32>>(); + if !pcm.is_empty() { + tx.send(pcm).unwrap() + } + }, + move |err| { + eprintln!("an error occurred on stream: {}", err); + }, + None, + )?; + stream.play()?; // loop to process the audio data forever (until the user stops the program) - println!("Transcribing audio..."); - for (i, _) in iter::repeat(()).enumerate() { - std::thread::sleep(std::time::Duration::from_millis(1000)); - let data = audio_ring_buffer_2.lock().unwrap().clone(); - let pcm_data: Vec<_> = data[..data.len() / channel_count as usize] - .iter() - .map(|v| *v as f32 / 32768.) - .collect(); - let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters); + println!("transcribing audio..."); + let mut buffered_pcm = vec![]; + let mut language_token_set = false; + while let Ok(pcm) = rx.recv() { + use rubato::Resampler; + + buffered_pcm.extend_from_slice(&pcm); + if buffered_pcm.len() < 10 * in_sample_rate { + continue; + } + let mut resampled_pcm = vec![]; + for buffered_pcm in buffered_pcm.chunks(1024) { + let pcm = resampler.process(&[&buffered_pcm], None)?; + resampled_pcm.extend_from_slice(&pcm[0]) + } + let pcm = resampled_pcm; + println!("{} {}", buffered_pcm.len(), pcm.len()); + buffered_pcm.clear(); + let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters); let mel_len = mel.len(); let mel = Tensor::from_vec( mel, @@ -614,9 +637,13 @@ pub fn main() -> Result<()> { )?; // on the first iteration, we detect the language and set the language token. 
- if i == 0 { + if !language_token_set { let language_token = match (args.model.is_multilingual(), args.language.clone()) { - (true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?), + (true, None) => Some(multilingual::detect_language( + decoder.model(), + &tokenizer, + &mel, + )?), (false, None) => None, (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) { Ok(token_id) => Some(token_id), @@ -627,47 +654,12 @@ pub fn main() -> Result<()> { } }; println!("language_token: {:?}", language_token); - dc.set_language_token(language_token); + decoder.set_language_token(language_token); + language_token_set = true; } - dc.run( - &mel, - Some(( - i as f64, - i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64, - )), - )?; - dc.reset_kv_cache(); + decoder.run(&mel, None)?; + decoder.reset_kv_cache(); } Ok(()) } - -fn record_audio( - device: &cpal::Device, - config: &cpal::SupportedStreamConfig, - milliseconds: u64, -) -> Result<Vec<i16>> { - let writer = Arc::new(Mutex::new(Vec::new())); - let writer_2 = writer.clone(); - let stream = device.build_input_stream( - &config.config(), - move |data: &[f32], _: &cpal::InputCallbackInfo| { - let processed = data - .iter() - .map(|v| (v * 32768.0) as i16) - .collect::<Vec<i16>>(); - writer_2.lock().unwrap().extend_from_slice(&processed); - }, - move |err| { - eprintln!("an error occurred on stream: {}", err); - }, - None, - )?; - stream.play()?; - std::thread::sleep(std::time::Duration::from_millis(milliseconds)); - drop(stream); - let data = writer.lock().unwrap().clone(); - let step = 3; - let data: Vec<i16> = data.iter().step_by(step).copied().collect(); - Ok(data) -}