mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example.
* Tweak the whisper microphone example more.
This commit is contained in:
@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
|
|||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
microphone = ["cpal"]
|
microphone = ["cpal", "rubato"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
mimi = ["cpal", "symphonia", "rubato"]
|
mimi = ["cpal", "symphonia", "rubato"]
|
||||||
depth_anything_v2 = ["palette", "enterpolation"]
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder};
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
use std::iter;
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod multilingual;
|
mod multilingual;
|
||||||
@ -18,7 +17,6 @@ mod multilingual;
|
|||||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||||
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
Normal(m::model::Whisper),
|
Normal(m::model::Whisper),
|
||||||
@ -479,6 +477,10 @@ struct Args {
|
|||||||
/// Print the full DecodingResult structure rather than just the text.
|
/// Print the full DecodingResult structure rather than just the text.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
|
|
||||||
|
/// The input device to use.
|
||||||
|
#[arg(long)]
|
||||||
|
device: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> Result<()> {
|
pub fn main() -> Result<()> {
|
||||||
@ -543,13 +545,12 @@ pub fn main() -> Result<()> {
|
|||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||||
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
|
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
|
||||||
};
|
};
|
||||||
let language_token = None;
|
let mut decoder = Decoder::new(
|
||||||
let mut dc = Decoder::new(
|
|
||||||
model,
|
model,
|
||||||
tokenizer.clone(),
|
tokenizer.clone(),
|
||||||
args.seed,
|
args.seed,
|
||||||
&device,
|
&device,
|
||||||
language_token,
|
/* language_token */ None,
|
||||||
args.task,
|
args.task,
|
||||||
args.timestamps,
|
args.timestamps,
|
||||||
args.verbose,
|
args.verbose,
|
||||||
@ -565,47 +566,69 @@ pub fn main() -> Result<()> {
|
|||||||
|
|
||||||
// Set up the input device and stream with the default input config.
|
// Set up the input device and stream with the default input config.
|
||||||
let host = cpal::default_host();
|
let host = cpal::default_host();
|
||||||
let _device = "default";
|
let audio_device = match args.device.as_ref() {
|
||||||
let _device = if _device == "default" {
|
None => host.default_input_device(),
|
||||||
host.default_input_device()
|
Some(device) => host
|
||||||
} else {
|
.input_devices()?
|
||||||
host.input_devices()?
|
.find(|x| x.name().map_or(false, |y| &y == device)),
|
||||||
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
|
|
||||||
}
|
}
|
||||||
.expect("failed to find input device");
|
.expect("failed to find the audio input device");
|
||||||
|
|
||||||
let _config = _device
|
let audio_config = audio_device
|
||||||
.default_input_config()
|
.default_input_config()
|
||||||
.expect("Failed to get default input config");
|
.expect("Failed to get default input config");
|
||||||
|
println!("audio config {audio_config:?}");
|
||||||
|
|
||||||
let channel_count = _config.channels() as usize;
|
let channel_count = audio_config.channels() as usize;
|
||||||
|
let in_sample_rate = audio_config.sample_rate().0 as usize;
|
||||||
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
|
let resample_ratio = 16000. / in_sample_rate as f64;
|
||||||
let audio_ring_buffer_2 = audio_ring_buffer.clone();
|
let mut resampler = rubato::FastFixedIn::new(
|
||||||
|
resample_ratio,
|
||||||
std::thread::spawn(move || loop {
|
10.,
|
||||||
let data = record_audio(&_device, &_config, 300).unwrap();
|
rubato::PolynomialDegree::Septic,
|
||||||
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
|
1024,
|
||||||
let max_len = data.len() * 16;
|
1,
|
||||||
let data_len = data.len();
|
)?;
|
||||||
let len = audio_ring_buffer.lock().unwrap().len();
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
if len > max_len {
|
let stream = audio_device.build_input_stream(
|
||||||
let mut data = audio_ring_buffer.lock().unwrap();
|
&audio_config.config(),
|
||||||
let new_data = data[data_len..].to_vec();
|
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
|
||||||
*data = new_data;
|
let pcm = pcm
|
||||||
}
|
.iter()
|
||||||
});
|
.step_by(channel_count)
|
||||||
|
.copied()
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
if !pcm.is_empty() {
|
||||||
|
tx.send(pcm).unwrap()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
move |err| {
|
||||||
|
eprintln!("an error occurred on stream: {}", err);
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
stream.play()?;
|
||||||
|
|
||||||
// loop to process the audio data forever (until the user stops the program)
|
// loop to process the audio data forever (until the user stops the program)
|
||||||
println!("Transcribing audio...");
|
println!("transcribing audio...");
|
||||||
for (i, _) in iter::repeat(()).enumerate() {
|
let mut buffered_pcm = vec![];
|
||||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
let mut language_token_set = false;
|
||||||
let data = audio_ring_buffer_2.lock().unwrap().clone();
|
while let Ok(pcm) = rx.recv() {
|
||||||
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
|
use rubato::Resampler;
|
||||||
.iter()
|
|
||||||
.map(|v| *v as f32 / 32768.)
|
buffered_pcm.extend_from_slice(&pcm);
|
||||||
.collect();
|
if buffered_pcm.len() < 10 * in_sample_rate {
|
||||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
continue;
|
||||||
|
}
|
||||||
|
let mut resampled_pcm = vec![];
|
||||||
|
for buffered_pcm in buffered_pcm.chunks(1024) {
|
||||||
|
let pcm = resampler.process(&[&buffered_pcm], None)?;
|
||||||
|
resampled_pcm.extend_from_slice(&pcm[0])
|
||||||
|
}
|
||||||
|
let pcm = resampled_pcm;
|
||||||
|
println!("{} {}", buffered_pcm.len(), pcm.len());
|
||||||
|
buffered_pcm.clear();
|
||||||
|
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
|
||||||
let mel_len = mel.len();
|
let mel_len = mel.len();
|
||||||
let mel = Tensor::from_vec(
|
let mel = Tensor::from_vec(
|
||||||
mel,
|
mel,
|
||||||
@ -614,9 +637,13 @@ pub fn main() -> Result<()> {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
// on the first iteration, we detect the language and set the language token.
|
// on the first iteration, we detect the language and set the language token.
|
||||||
if i == 0 {
|
if !language_token_set {
|
||||||
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
|
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
|
||||||
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
|
(true, None) => Some(multilingual::detect_language(
|
||||||
|
decoder.model(),
|
||||||
|
&tokenizer,
|
||||||
|
&mel,
|
||||||
|
)?),
|
||||||
(false, None) => None,
|
(false, None) => None,
|
||||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||||
Ok(token_id) => Some(token_id),
|
Ok(token_id) => Some(token_id),
|
||||||
@ -627,47 +654,12 @@ pub fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
println!("language_token: {:?}", language_token);
|
println!("language_token: {:?}", language_token);
|
||||||
dc.set_language_token(language_token);
|
decoder.set_language_token(language_token);
|
||||||
|
language_token_set = true;
|
||||||
}
|
}
|
||||||
dc.run(
|
decoder.run(&mel, None)?;
|
||||||
&mel,
|
decoder.reset_kv_cache();
|
||||||
Some((
|
|
||||||
i as f64,
|
|
||||||
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
|
|
||||||
)),
|
|
||||||
)?;
|
|
||||||
dc.reset_kv_cache();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_audio(
|
|
||||||
device: &cpal::Device,
|
|
||||||
config: &cpal::SupportedStreamConfig,
|
|
||||||
milliseconds: u64,
|
|
||||||
) -> Result<Vec<i16>> {
|
|
||||||
let writer = Arc::new(Mutex::new(Vec::new()));
|
|
||||||
let writer_2 = writer.clone();
|
|
||||||
let stream = device.build_input_stream(
|
|
||||||
&config.config(),
|
|
||||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
||||||
let processed = data
|
|
||||||
.iter()
|
|
||||||
.map(|v| (v * 32768.0) as i16)
|
|
||||||
.collect::<Vec<i16>>();
|
|
||||||
writer_2.lock().unwrap().extend_from_slice(&processed);
|
|
||||||
},
|
|
||||||
move |err| {
|
|
||||||
eprintln!("an error occurred on stream: {}", err);
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
|
|
||||||
drop(stream);
|
|
||||||
let data = writer.lock().unwrap().clone();
|
|
||||||
let step = 3;
|
|
||||||
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
|
|
||||||
Ok(data)
|
|
||||||
}
|
|
||||||
|
Reference in New Issue
Block a user