Refactor the whisper microphone example. (#2523)

* Refactor the whisper microphone example.

* Tweak the whisper microphone example more.
This commit is contained in:
Laurent Mazare
2024-10-01 00:24:17 +02:00
committed by GitHub
parent aa35bf2ff5
commit 6110ad8d4f
2 changed files with 73 additions and 81 deletions

View File

@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"] onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"] metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"] microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"] encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"] mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"] depth_anything_v2 = ["palette", "enterpolation"]

View File

@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng}; use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod multilingual; mod multilingual;
@ -18,7 +17,6 @@ mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config}; use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model { pub enum Model {
Normal(m::model::Whisper), Normal(m::model::Whisper),
@ -479,6 +477,10 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text. /// Print the full DecodingResult structure rather than just the text.
#[arg(long)] #[arg(long)]
verbose: bool, verbose: bool,
/// The input device to use.
#[arg(long)]
device: Option<String>,
} }
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
@ -543,13 +545,12 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? }; unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?) Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
}; };
let language_token = None; let mut decoder = Decoder::new(
let mut dc = Decoder::new(
model, model,
tokenizer.clone(), tokenizer.clone(),
args.seed, args.seed,
&device, &device,
language_token, /* language_token */ None,
args.task, args.task,
args.timestamps, args.timestamps,
args.verbose, args.verbose,
@ -565,47 +566,69 @@ pub fn main() -> Result<()> {
// Set up the input device and stream with the default input config. // Set up the input device and stream with the default input config.
let host = cpal::default_host(); let host = cpal::default_host();
let _device = "default"; let audio_device = match args.device.as_ref() {
let _device = if _device == "default" { None => host.default_input_device(),
host.default_input_device() Some(device) => host
} else { .input_devices()?
host.input_devices()? .find(|x| x.name().map_or(false, |y| &y == device)),
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
} }
.expect("failed to find input device"); .expect("failed to find the audio input device");
let _config = _device let audio_config = audio_device
.default_input_config() .default_input_config()
.expect("Failed to get default input config"); .expect("Failed to get default input config");
println!("audio config {audio_config:?}");
let channel_count = _config.channels() as usize; let channel_count = audio_config.channels() as usize;
let in_sample_rate = audio_config.sample_rate().0 as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new())); let resample_ratio = 16000. / in_sample_rate as f64;
let audio_ring_buffer_2 = audio_ring_buffer.clone(); let mut resampler = rubato::FastFixedIn::new(
resample_ratio,
std::thread::spawn(move || loop { 10.,
let data = record_audio(&_device, &_config, 300).unwrap(); rubato::PolynomialDegree::Septic,
audio_ring_buffer.lock().unwrap().extend_from_slice(&data); 1024,
let max_len = data.len() * 16; 1,
let data_len = data.len(); )?;
let len = audio_ring_buffer.lock().unwrap().len(); let (tx, rx) = std::sync::mpsc::channel();
if len > max_len { let stream = audio_device.build_input_stream(
let mut data = audio_ring_buffer.lock().unwrap(); &audio_config.config(),
let new_data = data[data_len..].to_vec(); move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
*data = new_data; let pcm = pcm
.iter()
.step_by(channel_count)
.copied()
.collect::<Vec<f32>>();
if !pcm.is_empty() {
tx.send(pcm).unwrap()
} }
}); },
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
// loop to process the audio data forever (until the user stops the program) // loop to process the audio data forever (until the user stops the program)
println!("Transcribing audio..."); println!("transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() { let mut buffered_pcm = vec![];
std::thread::sleep(std::time::Duration::from_millis(1000)); let mut language_token_set = false;
let data = audio_ring_buffer_2.lock().unwrap().clone(); while let Ok(pcm) = rx.recv() {
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize] use rubato::Resampler;
.iter()
.map(|v| *v as f32 / 32768.) buffered_pcm.extend_from_slice(&pcm);
.collect(); if buffered_pcm.len() < 10 * in_sample_rate {
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters); continue;
}
let mut resampled_pcm = vec![];
for buffered_pcm in buffered_pcm.chunks(1024) {
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0])
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
buffered_pcm.clear();
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
let mel_len = mel.len(); let mel_len = mel.len();
let mel = Tensor::from_vec( let mel = Tensor::from_vec(
mel, mel,
@ -614,9 +637,13 @@ pub fn main() -> Result<()> {
)?; )?;
// on the first iteration, we detect the language and set the language token. // on the first iteration, we detect the language and set the language token.
if i == 0 { if !language_token_set {
let language_token = match (args.model.is_multilingual(), args.language.clone()) { let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?), (true, None) => Some(multilingual::detect_language(
decoder.model(),
&tokenizer,
&mel,
)?),
(false, None) => None, (false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) { (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id), Ok(token_id) => Some(token_id),
@ -627,47 +654,12 @@ pub fn main() -> Result<()> {
} }
}; };
println!("language_token: {:?}", language_token); println!("language_token: {:?}", language_token);
dc.set_language_token(language_token); decoder.set_language_token(language_token);
language_token_set = true;
} }
dc.run( decoder.run(&mel, None)?;
&mel, decoder.reset_kv_cache();
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
} }
Ok(()) Ok(())
} }
/// Record roughly `milliseconds` of audio from `device` using `config`,
/// returning the captured samples as `i16` PCM, decimated by a fixed
/// factor of 3 (e.g. a naive 48 kHz -> 16 kHz downsample).
///
/// NOTE(review): the 3x decimation assumes the device sample rate is
/// three times the target rate — confirm against the device config.
fn record_audio(
    device: &cpal::Device,
    config: &cpal::SupportedStreamConfig,
    milliseconds: u64,
) -> Result<Vec<i16>> {
    // Shared buffer that the audio callback appends into while this
    // thread sleeps for the requested duration.
    let samples = Arc::new(Mutex::new(Vec::new()));
    let callback_samples = samples.clone();
    let stream = device.build_input_stream(
        &config.config(),
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            // Scale f32 samples in [-1, 1] up to the i16 range.
            let mut buf = callback_samples.lock().unwrap();
            for &v in data {
                buf.push((v * 32768.0) as i16);
            }
        },
        move |err| {
            eprintln!("an error occurred on stream: {}", err);
        },
        None,
    )?;
    stream.play()?;
    // Let the callback run for the requested duration, then stop the
    // stream by dropping it before reading the accumulated samples.
    std::thread::sleep(std::time::Duration::from_millis(milliseconds));
    drop(stream);
    let recorded = samples.lock().unwrap().clone();
    // Keep every 3rd sample (fixed decimation factor).
    let step = 3;
    Ok(recorded.into_iter().step_by(step).collect())
}