diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index af5f791a..e07af856 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -32,6 +32,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 symphonia = { version = "0.5.3", features = ["all"] }
 tokenizers = { workspace = true, features = ["onig"] }
+cpal = { version = "0.15.2", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -61,6 +62,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
 metal = ["candle/metal", "candle-nn/metal"]
+microphone = ["cpal"]
 
 [[example]]
 name = "llama_multiprocess"
@@ -77,3 +79,7 @@ required-features = ["onnx"]
 [[example]]
 name = "onnx_basics"
 required-features = ["onnx"]
+
+[[example]]
+name = "whisper-microphone"
+required-features = ["microphone"]
diff --git a/candle-examples/examples/whisper-microphone/main.rs b/candle-examples/examples/whisper-microphone/main.rs
new file mode 100644
index 00000000..9f7d5b82
--- /dev/null
+++ b/candle-examples/examples/whisper-microphone/main.rs
@@ -0,0 +1,673 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, IndexOp, Tensor};
+use candle_nn::{ops::softmax, VarBuilder};
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use rand::{distributions::Distribution, SeedableRng};
+use std::iter;
+use tokenizers::Tokenizer;
+
+mod multilingual;
+
+use candle_transformers::models::whisper::{self as m, audio, Config};
+
+use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
+use std::sync::{Arc, Mutex};
+
+pub enum Model {
+    Normal(m::model::Whisper),
+    Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these.
+impl Model {
+    pub fn config(&self) -> &Config {
+        match self {
+            Self::Normal(m) => &m.config,
+            Self::Quantized(m) => &m.config,
+        }
+    }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.encoder.forward(x, flush),
+            Self::Quantized(m) => m.encoder.forward(x, flush),
+        }
+    }
+
+    pub fn decoder_forward(
+        &mut self,
+        x: &Tensor,
+        xa: &Tensor,
+        flush: bool,
+    ) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.forward(x, xa, flush),
+            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+        }
+    }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.final_linear(x),
+            Self::Quantized(m) => m.decoder.final_linear(x),
+        }
+    }
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone)]
+struct DecodingResult {
+    tokens: Vec<u32>,
+    text: String,
+    avg_logprob: f64,
+    no_speech_prob: f64,
+    temperature: f64,
+    compression_ratio: f64,
+}
+
+#[allow(dead_code)]
+#[derive(Debug, Clone)]
+struct Segment {
+    start: f64,
+    duration: f64,
+    dr: DecodingResult,
+}
+
+struct Decoder {
+    model: Model,
+    rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    timestamps: bool,
+    verbose: bool,
+    tokenizer: Tokenizer,
+    suppress_tokens: Tensor,
+    sot_token: u32,
+    transcribe_token: u32,
+    translate_token: u32,
+    eot_token: u32,
+    no_speech_token: u32,
+    no_timestamps_token: u32,
+    language_token: Option<u32>,
+}
+
+impl Decoder {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        device: &Device,
+        language_token: Option<u32>,
+        task: Option<Task>,
+        timestamps: bool,
+        verbose: bool,
+    ) -> Result<Self> {
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
+        // Suppress the notimestamps token when in timestamps mode.
+        // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
+            .map(|i| {
+                if model.config().suppress_tokens.contains(&i)
+                    || timestamps && i == no_timestamps_token
+                {
+                    f32::NEG_INFINITY
+                } else {
+                    0f32
+                }
+            })
+            .collect();
+        let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = m::NO_SPEECH_TOKENS
+            .iter()
+            .find_map(|token| token_id(&tokenizer, token).ok());
+        let no_speech_token = match no_speech_token {
+            None => anyhow::bail!("unable to find any non-speech token"),
+            Some(n) => n,
+        };
+        Ok(Self {
+            model,
+            rng: rand::rngs::StdRng::seed_from_u64(seed),
+            tokenizer,
+            task,
+            timestamps,
+            verbose,
+            suppress_tokens,
+            sot_token,
+            transcribe_token,
+            translate_token,
+            eot_token,
+            no_speech_token,
+            language_token,
+            no_timestamps_token,
+        })
+    }
+
+    fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
+        let model = &mut self.model;
+        let audio_features = model.encoder_forward(mel, true)?;
+        if self.verbose {
+            println!("audio features: {:?}", audio_features.dims());
+        }
+        let sample_len = model.config().max_target_positions / 2;
+        let mut sum_logprob = 0f64;
+        let mut no_speech_prob = f64::NAN;
+        let mut tokens = vec![self.sot_token];
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token);
+        }
+        match self.task {
+            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
+        }
+        for i in 0..sample_len {
+            let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
+
+            // The model expects a batch dim but this inference loop does not handle
+            // it so we add it at this point.
+            let tokens_t = tokens_t.unsqueeze(0)?;
+            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
+
+            // Extract the no speech probability on the first iteration by looking at the first
+            // token logits and the probability for the according token.
+            if i == 0 {
+                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                no_speech_prob = softmax(&logits, 0)?
+                    .i(self.no_speech_token as usize)?
+                    .to_scalar::<f32>()? as f64;
+            }
+
+            let (_, seq_len, _) = ys.dims3()?;
+            let logits = model
+                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .i(0)?
+                .i(0)?;
+            // TODO: Besides suppress tokens, we should apply the heuristics from
+            // ApplyTimestampRules, i.e.:
+            // - Timestamps come in pairs, except before EOT.
+            // - Timestamps should be non-decreasing.
+            // - If the sum of the probabilities of timestamps is higher than any other tokens,
+            //   only consider timestamps when sampling.
+            // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
+            let logits = logits.broadcast_add(&self.suppress_tokens)?;
+            let next_token = if t > 0f64 {
+                let prs = softmax(&(&logits / t)?, 0)?;
+                let logits_v: Vec<f32> = prs.to_vec1()?;
+                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+                distr.sample(&mut self.rng) as u32
+            } else {
+                let logits_v: Vec<f32> = logits.to_vec1()?;
+                logits_v
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, u), (_, v)| u.total_cmp(v))
+                    .map(|(i, _)| i as u32)
+                    .unwrap()
+            };
+            tokens.push(next_token);
+            let prob = softmax(&logits, candle::D::Minus1)?
+                .i(next_token as usize)?
+                .to_scalar::<f32>()? as f64;
+            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
+                break;
+            }
+            sum_logprob += prob.ln();
+        }
+        let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
+        let avg_logprob = sum_logprob / tokens.len() as f64;
+
+        Ok(DecodingResult {
+            tokens,
+            text,
+            avg_logprob,
+            no_speech_prob,
+            temperature: t,
+            compression_ratio: f64::NAN,
+        })
+    }
+
+    fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
+            let dr: Result<DecodingResult> = self.decode(segment, t);
+            if i == m::TEMPERATURES.len() - 1 {
+                return dr;
+            }
+            // On errors, we try again with a different temperature.
+            match dr {
+                Ok(dr) => {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
+                        return Ok(dr);
+                    }
+                }
+                Err(err) => {
+                    println!("Error running at {t}: {err}")
+                }
+            }
+        }
+        unreachable!()
+    }
+
+    fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
+        let (_, _, content_frames) = mel.dims3()?;
+        let mut seek = 0;
+        let mut segments = vec![];
+        while seek < content_frames {
+            let start = std::time::Instant::now();
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
+            let mel_segment = mel.narrow(2, seek, segment_size)?;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let dr = self.decode_with_fallback(&mel_segment)?;
+            seek += segment_size;
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
+                println!("no speech detected, skipping {seek} {dr:?}");
+                continue;
+            }
+            let segment = Segment {
+                start: time_offset,
+                duration: segment_duration,
+                dr,
+            };
+            if self.timestamps {
+                println!(
+                    "{:.1}s -- {:.1}s",
+                    segment.start,
+                    segment.start + segment.duration,
+                );
+                let mut tokens_to_decode = vec![];
+                let mut prev_timestamp_s = 0f32;
+                for &token in segment.dr.tokens.iter() {
+                    if token == self.sot_token || token == self.eot_token {
+                        continue;
+                    }
+                    // The no_timestamp_token is the last before the timestamp ones.
+                    if token > self.no_timestamps_token {
+                        let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
+                        if !tokens_to_decode.is_empty() {
+                            let text = self
+                                .tokenizer
+                                .decode(&tokens_to_decode, true)
+                                .map_err(E::msg)?;
+                            println!("  {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
+                            tokens_to_decode.clear()
+                        }
+                        prev_timestamp_s = timestamp_s;
+                    } else {
+                        tokens_to_decode.push(token)
+                    }
+                }
+                if !tokens_to_decode.is_empty() {
+                    let text = self
+                        .tokenizer
+                        .decode(&tokens_to_decode, true)
+                        .map_err(E::msg)?;
+                    if !text.is_empty() {
+                        println!("  {:.1}s-...: {}", prev_timestamp_s, text);
+                    }
+                    tokens_to_decode.clear()
+                }
+            } else {
+                match times {
+                    Some((start, end)) => {
+                        println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
+                    }
+                    None => {
+                        println!(
+                            "{:.1}s -- {:.1}s: {}",
+                            segment.start,
+                            segment.start + segment.duration,
+                            segment.dr.text,
+                        )
+                    }
+                }
+            }
+            if self.verbose {
+                println!("{seek}: {segment:?}, in {:?}", start.elapsed());
+            }
+            segments.push(segment)
+        }
+        Ok(segments)
+    }
+
+    fn set_language_token(&mut self, language_token: Option<u32>) {
+        self.language_token = language_token;
+    }
+
+    #[allow(dead_code)]
+    fn reset_kv_cache(&mut self) {
+        match &mut self.model {
+            Model::Normal(m) => m.reset_kv_cache(),
+            Model::Quantized(m) => m.reset_kv_cache(),
+        }
+    }
+
+    fn model(&mut self) -> &mut Model {
+        &mut self.model
+    }
+}
+
+pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
+    match tokenizer.token_to_id(token) {
+        None => candle::bail!("no token-id for {token}"),
+        Some(id) => Ok(id),
+    }
+}
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Task {
+    Transcribe,
+    Translate,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
+enum WhichModel {
+    Tiny,
+    #[value(name = "tiny.en")]
+    TinyEn,
+    Base,
+    #[value(name = "base.en")]
+    BaseEn,
+    Small,
+    #[value(name = "small.en")]
+    SmallEn,
+    Medium,
+    #[value(name = "medium.en")]
+    MediumEn,
+    Large,
+    LargeV2,
+    LargeV3,
+    #[value(name = "distil-medium.en")]
+    DistilMediumEn,
+    #[value(name = "distil-large-v2")]
+    DistilLargeV2,
+}
+
+impl WhichModel {
+    fn is_multilingual(&self) -> bool {
+        match self {
+            Self::Tiny
+            | Self::Base
+            | Self::Small
+            | Self::Medium
+            | Self::Large
+            | Self::LargeV2
+            | Self::LargeV3
+            | Self::DistilLargeV2 => true,
+            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
+                false
+            }
+        }
+    }
+
+    fn model_and_revision(&self) -> (&'static str, &'static str) {
+        match self {
+            Self::Tiny => ("openai/whisper-tiny", "main"),
+            Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
+            Self::Base => ("openai/whisper-base", "refs/pr/22"),
+            Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
+            Self::Small => ("openai/whisper-small", "main"),
+            Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
+            Self::Medium => ("openai/whisper-medium", "main"),
+            Self::MediumEn => ("openai/whisper-medium.en", "main"),
+            Self::Large => ("openai/whisper-large", "refs/pr/36"),
+            Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
+            Self::LargeV3 => ("openai/whisper-large-v3", "main"),
+            Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
+            Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
+        }
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    /// The model to use, check out available models:
+    /// https://huggingface.co/models?search=whisper
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// The model to be used, can be tiny, small, medium.
+    #[arg(long, default_value = "tiny.en")]
+    model: WhichModel,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    quantized: bool,
+
+    /// Language.
+    #[arg(long)]
+    language: Option<String>,
+
+    /// Task, when no task is specified, the input tokens contain only the sot token which can
+    /// improve things when in no-timestamp mode.
+    #[arg(long)]
+    task: Option<Task>,
+
+    /// Timestamps mode, this is not fully implemented yet.
+    #[arg(long)]
+    timestamps: bool,
+
+    /// Print the full DecodingResult structure rather than just the text.
+    #[arg(long)]
+    verbose: bool,
+}
+
+pub fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let device = candle_examples::device(args.cpu)?;
+    let (default_model, default_revision) = if args.quantized {
+        ("lmz/candle-whisper", "main")
+    } else {
+        args.model.model_and_revision()
+    };
+    let default_model = default_model.to_string();
+    let default_revision = default_revision.to_string();
+    let (model_id, revision) = match (args.model_id, args.revision) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let (config_filename, tokenizer_filename, weights_filename) = {
+        let api = Api::new()?;
+        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+        let (config, tokenizer, model) = if args.quantized {
+            let ext = match args.model {
+                WhichModel::TinyEn => "tiny-en",
+                WhichModel::Tiny => "tiny",
+                _ => unimplemented!("no quantized support for {:?}", args.model),
+            };
+            (
+                repo.get(&format!("config-{ext}.json"))?,
+                repo.get(&format!("tokenizer-{ext}.json"))?,
+                repo.get(&format!("model-{ext}-q80.gguf"))?,
+            )
+        } else {
+            let config = repo.get("config.json")?;
+            let tokenizer = repo.get("tokenizer.json")?;
+            let model = repo.get("model.safetensors")?;
+            (config, tokenizer, model)
+        };
+        (config, tokenizer, model)
+    };
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let model = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &weights_filename,
+            &device,
+        )?;
+        Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
+    } else {
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
+        Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
+    };
+    let language_token = None;
+    let mut dc = Decoder::new(
+        model,
+        tokenizer.clone(),
+        args.seed,
+        &device,
+        language_token,
+        args.task,
+        args.timestamps,
+        args.verbose,
+    )?;
+
+    let mel_bytes = match config.num_mel_bins {
+        80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
+        128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
+        nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
+    };
+    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
+    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
+
+    // Set up the input device and stream with the default input config.
+    let host = cpal::default_host();
+    let _device = "default";
+    let _device = if _device == "default" {
+        host.default_input_device()
+    } else {
+        host.input_devices()?
+            .find(|x| x.name().map(|y| y == _device).unwrap_or(false))
+    }
+    .expect("failed to find input device");
+
+    let _config = _device
+        .default_input_config()
+        .expect("Failed to get default input config");
+
+    let channel_count = _config.channels() as usize;
+
+    let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
+    let audio_ring_buffer_2 = audio_ring_buffer.clone();
+
+    std::thread::spawn(move || loop {
+        let data = record_audio(&_device, &_config, 300).unwrap();
+        audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
+        let max_len = data.len() * 16;
+        let data_len = data.len();
+        let len = audio_ring_buffer.lock().unwrap().len();
+        if len > max_len {
+            let mut data = audio_ring_buffer.lock().unwrap();
+            let new_data = data[data_len..].to_vec();
+            *data = new_data;
+        }
+    });
+
+    // loop to process the audio data forever (until the user stops the program)
+    println!("Transcribing audio...");
+    for (i, _) in iter::repeat(()).enumerate() {
+        std::thread::sleep(std::time::Duration::from_millis(1000));
+        let data = audio_ring_buffer_2.lock().unwrap().clone();
+        let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
+            .iter()
+            .map(|v| *v as f32 / 32768.)
+            .collect();
+        let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
+        let mel_len = mel.len();
+        let mel = Tensor::from_vec(
+            mel,
+            (1, config.num_mel_bins, mel_len / config.num_mel_bins),
+            &device,
+        )?;
+
+        // on the first iteration, we detect the language and set the language token.
+        if i == 0 {
+            let language_token = match (args.model.is_multilingual(), args.language.clone()) {
+                (true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
+                (false, None) => None,
+                (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
+                    Ok(token_id) => Some(token_id),
+                    Err(_) => anyhow::bail!("language {language} is not supported"),
+                },
+                (false, Some(_)) => {
+                    anyhow::bail!("a language cannot be set for non-multilingual models")
+                }
+            };
+            println!("language_token: {:?}", language_token);
+            dc.set_language_token(language_token);
+        }
+        dc.run(
+            &mel,
+            Some((
+                i as f64,
+                i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
+            )),
+        )?;
+        dc.reset_kv_cache();
+    }
+
+    Ok(())
+}
+
+fn record_audio(
+    device: &cpal::Device,
+    config: &cpal::SupportedStreamConfig,
+    milliseconds: u64,
+) -> Result<Vec<i16>> {
+    let writer = Arc::new(Mutex::new(Vec::new()));
+    let writer_2 = writer.clone();
+    let stream = device.build_input_stream(
+        &config.config(),
+        move |data: &[f32], _: &cpal::InputCallbackInfo| {
+            let processed = data
+                .iter()
+                .map(|v| (v * 32768.0) as i16)
+                .collect::<Vec<i16>>();
+            writer_2.lock().unwrap().extend_from_slice(&processed);
+        },
+        move |err| {
+            eprintln!("an error occurred on stream: {}", err);
+        },
+        None,
+    )?;
+    stream.play()?;
+    std::thread::sleep(std::time::Duration::from_millis(milliseconds));
+    drop(stream);
+    let data = writer.lock().unwrap().clone();
+    let step = 3;
+    let data: Vec<i16> = data.iter().step_by(step).copied().collect();
+    Ok(data)
+}
diff --git a/candle-examples/examples/whisper-microphone/multilingual.rs b/candle-examples/examples/whisper-microphone/multilingual.rs
new file mode 100644
index 00000000..053755cf
--- /dev/null
+++ b/candle-examples/examples/whisper-microphone/multilingual.rs
@@ -0,0 +1,137 @@
+use crate::{token_id, Model};
+use candle::{IndexOp, Result, Tensor, D};
+use candle_transformers::models::whisper::{self as m};
+use tokenizers::Tokenizer;
+
+const LANGUAGES: [(&str, &str); 99] = [
+    ("en", "english"),
+    ("zh", "chinese"),
+    ("de", "german"),
+    ("es", "spanish"),
+    ("ru", "russian"),
+    ("ko", "korean"),
+    ("fr", "french"),
+    ("ja", "japanese"),
+    ("pt", "portuguese"),
+    ("tr", "turkish"),
+    ("pl", "polish"),
+    ("ca", "catalan"),
+    ("nl", "dutch"),
+    ("ar", "arabic"),
+    ("sv", "swedish"),
+    ("it", "italian"),
+    ("id", "indonesian"),
+    ("hi", "hindi"),
+    ("fi", "finnish"),
+    ("vi", "vietnamese"),
+    ("he", "hebrew"),
+    ("uk", "ukrainian"),
+    ("el", "greek"),
+    ("ms", "malay"),
+    ("cs", "czech"),
+    ("ro", "romanian"),
+    ("da", "danish"),
+    ("hu", "hungarian"),
+    ("ta", "tamil"),
+    ("no", "norwegian"),
+    ("th", "thai"),
+    ("ur", "urdu"),
+    ("hr", "croatian"),
+    ("bg", "bulgarian"),
+    ("lt", "lithuanian"),
+    ("la", "latin"),
+    ("mi", "maori"),
+    ("ml", "malayalam"),
+    ("cy", "welsh"),
+    ("sk", "slovak"),
+    ("te", "telugu"),
+    ("fa", "persian"),
+    ("lv", "latvian"),
+    ("bn", "bengali"),
+    ("sr", "serbian"),
+    ("az", "azerbaijani"),
+    ("sl", "slovenian"),
+    ("kn", "kannada"),
+    ("et", "estonian"),
+    ("mk", "macedonian"),
+    ("br", "breton"),
+    ("eu", "basque"),
+    ("is", "icelandic"),
+    ("hy", "armenian"),
+    ("ne", "nepali"),
+    ("mn", "mongolian"),
+    ("bs", "bosnian"),
+    ("kk", "kazakh"),
+    ("sq", "albanian"),
+    ("sw", "swahili"),
+    ("gl", "galician"),
+    ("mr", "marathi"),
+    ("pa", "punjabi"),
+    ("si", "sinhala"),
+    ("km", "khmer"),
+    ("sn", "shona"),
+    ("yo", "yoruba"),
+    ("so", "somali"),
+    ("af", "afrikaans"),
+    ("oc", "occitan"),
+    ("ka", "georgian"),
+    ("be", "belarusian"),
+    ("tg", "tajik"),
+    ("sd", "sindhi"),
+    ("gu", "gujarati"),
+    ("am", "amharic"),
+    ("yi", "yiddish"),
+    ("lo", "lao"),
+    ("uz", "uzbek"),
+    ("fo", "faroese"),
+    ("ht", "haitian creole"),
+    ("ps", "pashto"),
+    ("tk", "turkmen"),
+    ("nn", "nynorsk"),
+    ("mt", "maltese"),
+    ("sa", "sanskrit"),
+    ("lb", "luxembourgish"),
+    ("my", "myanmar"),
+    ("bo", "tibetan"),
+    ("tl", "tagalog"),
+    ("mg", "malagasy"),
+    ("as", "assamese"),
+    ("tt", "tatar"),
+    ("haw", "hawaiian"),
+    ("ln", "lingala"),
+    ("ha", "hausa"),
+    ("ba", "bashkir"),
+    ("jw", "javanese"),
+    ("su", "sundanese"),
+];
+
+/// Returns the token id for the selected language.
+pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
+    let (_bsize, _, seq_len) = mel.dims3()?;
+    let mel = mel.narrow(
+        2,
+        0,
+        usize::min(seq_len, model.config().max_source_positions),
+    )?;
+    let device = mel.device();
+    let language_token_ids = LANGUAGES
+        .iter()
+        .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
+        .collect::<Result<Vec<_>>>()?;
+    let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
+    let audio_features = model.encoder_forward(&mel, true)?;
+    let tokens = Tensor::new(&[[sot_token]], device)?;
+    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
+    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+    let logits = logits.index_select(&language_token_ids, 0)?;
+    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
+    let probs = probs.to_vec1::<f32>()?;
+    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
+    probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for ((_, language), p) in probs.iter().take(5) {
+        println!("{language}: {p}")
+    }
+    let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
+    Ok(language)
+}
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 74f708e6..593ed373 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -129,6 +129,10 @@ impl MultiHeadAttention {
             .flatten_from(2)?;
         Ok(wv)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None;
+    }
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
@@ -193,6 +197,13 @@ impl ResidualAttentionBlock {
         )?;
         x + mlp
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.attn.reset_kv_cache();
+        if let Some((attn, _)) = &mut self.cross_attn {
+            attn.reset_kv_cache();
+        }
+    }
 }
 
 fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
@@ -350,6 +361,12 @@ impl TextDecoder {
         };
         Ok(logits)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for block in self.blocks.iter_mut() {
+            block.reset_kv_cache();
+        }
+    }
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
@@ -370,4 +387,12 @@ impl Whisper {
             config,
         })
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.encoder
+            .blocks
+            .iter_mut()
+            .for_each(|b| b.reset_kv_cache());
+        self.decoder.reset_kv_cache();
+    }
 }
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index dac78be9..2db363c6 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -126,6 +126,10 @@ impl MultiHeadAttention {
             .flatten_from(2)?;
         Ok(wv)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None;
+    }
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
@@ -189,6 +193,13 @@ impl ResidualAttentionBlock {
             .apply(&self.mlp_linear2)?;
         x + mlp
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.attn.reset_kv_cache();
+        if let Some((attn, _)) = &mut self.cross_attn {
+            attn.reset_kv_cache();
+        }
+    }
 }
 
 fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
@@ -281,6 +292,12 @@ impl AudioEncoder {
         let x = self.ln_post.forward(&x)?;
         Ok(x)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for block in self.blocks.iter_mut() {
+            block.reset_kv_cache();
+        }
+    }
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
@@ -348,6 +365,12 @@ impl TextDecoder {
         };
         Ok(logits)
    }
+
+    pub fn reset_kv_cache(&mut self) {
+        for block in self.blocks.iter_mut() {
+            block.reset_kv_cache();
+        }
+    }
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
@@ -368,4 +391,9 @@ impl Whisper {
             config,
         })
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.encoder.reset_kv_cache();
+        self.decoder.reset_kv_cache();
+    }
 }
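
Usage note (not part of the patch; invocation inferred from the Cargo.toml entries and the Args struct above, assuming the patch is applied at the repository root and a default input device is available):

    cargo run --release --example whisper-microphone --features microphone -- --model tiny.en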