From 8cc0a183ba25d2de98463fa004314f7c5df1bfea Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 4 Mar 2024 14:13:01 +0100
Subject: [PATCH] Speaker embeddings computation for metavoice. (#1800)

* Speaker embeddings computation for metavoice.

* Compute the speaker embeddings.
---
 candle-transformers/src/models/metavoice.rs   | 126 +++++++++++++++---
 .../src/models/whisper/audio.rs               |   2 +-
 2 files changed, 107 insertions(+), 21 deletions(-)

diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 0ab19041..993f73ef 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -1,4 +1,4 @@
-use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 
 // Equivalent to torch.repeat_interleave
@@ -13,22 +13,41 @@ pub mod speaker_encoder {
 
     #[derive(Debug, Clone, serde::Deserialize)]
     pub struct Config {
-        pub mel_window_step: usize,
-        pub mel_n_channels: usize,
         pub sampling_rate: usize,
         pub partial_n_frames: usize,
         pub model_hidden_size: usize,
         pub model_embedding_size: usize,
         pub model_num_layers: usize,
+        pub mel_window_length: usize,
+        pub mel_window_step: usize,
+        pub mel_n_channels: usize,
+    }
+
+    impl Config {
+        pub fn cfg() -> Self {
+            Self {
+                sampling_rate: 16_000,
+                partial_n_frames: 160,
+                model_hidden_size: 256,
+                model_embedding_size: 256,
+                model_num_layers: 3,
+                mel_window_length: 25,
+                mel_window_step: 10,
+                mel_n_channels: 40,
+            }
+        }
     }
 
     pub struct Model {
         lstms: Vec<candle_nn::LSTM>,
         linear: Linear,
+        cfg: Config,
     }
 
+    type Slice = (usize, usize);
+
     impl Model {
-        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
             let mut lstms = Vec::with_capacity(cfg.model_num_layers);
             let vb_l = vb.pp("lstm");
             for layer_idx in 0..cfg.model_num_layers {
@@ -50,36 +69,103 @@
                 true,
                 vb.pp("linear"),
             )?;
-            Ok(Self { lstms, linear })
+            Ok(Self { lstms, linear, cfg })
         }
 
         fn compute_partial_slices(
-            _n_samples: usize,
-            _rate: f64,
-            _min_coverage: f64,
-        ) -> Result<(Tensor, Tensor)> {
-            todo!()
+            &self,
+            n_samples: usize,
+            rate: f64,
+            min_coverage: f64,
+        ) -> (Vec<Slice>, Vec<Slice>) {
+            let c = &self.cfg;
+            // Compute how many frames separate two partial utterances
+            let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
+            let n_frames = n_samples / samples_per_frame + 1;
+            let frame_step =
+                (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
+            let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
+            // Compute the slices.
+            let mut wav_slices = vec![];
+            let mut mel_slices = vec![];
+            for i in (0..steps).step_by(frame_step) {
+                let mel_range = (i, i + c.partial_n_frames);
+                let wav_range = (
+                    i * samples_per_frame,
+                    (i + c.partial_n_frames) * samples_per_frame,
+                );
+                mel_slices.push(mel_range);
+                wav_slices.push(wav_range);
+            }
+            // Evaluate whether extra padding is warranted or not.
+            let last_wav_range = match wav_slices.last() {
+                None => return (wav_slices, mel_slices),
+                Some(l) => *l,
+            };
+            let coverage = (n_samples - last_wav_range.0) as f64
+                / (last_wav_range.1 - last_wav_range.0) as f64;
+            if coverage < min_coverage && mel_slices.len() > 1 {
+                mel_slices.pop();
+                wav_slices.pop();
+            }
+            (wav_slices, mel_slices)
         }
 
-        pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
-            let (_wav_slices, _mel_slices) =
-                Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
-            todo!()
+        pub fn embed_utterance(
+            &self,
+            wav: &[f32],
+            mel_filters: &[f32],
+            rate: f64,
+            min_c: f64,
+            device: &Device,
+        ) -> Result<Tensor> {
+            let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
+            let max_wave_length = match wav_slices.last() {
+                Some(v) => v.1,
+                None => candle::bail!("empty wav slices"),
+            };
+            let wav = if max_wave_length > wav.len() {
+                let mut wav = wav.to_vec();
+                wav.resize(max_wave_length, 0.0);
+                std::borrow::Cow::Owned(wav)
+            } else {
+                std::borrow::Cow::Borrowed(wav)
+            };
+            let mel = crate::models::whisper::audio::log_mel_spectrogram_(
+                wav.as_ref(),
+                mel_filters,
+                /* fft_size */ self.cfg.mel_window_length,
+                /* fft_step */ self.cfg.mel_window_step,
+                self.cfg.mel_n_channels,
+                false,
+            );
+            let mels = mel_slices
+                .iter()
+                .flat_map(|s| [mel[s.0], mel[s.1]])
+                .collect::<Vec<_>>();
+            let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
+            let partial_embeds = self.forward(&mels)?;
+            let raw_embed = partial_embeds.mean(0)?;
+            let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
+            raw_embed.broadcast_div(&norm)
+        }
     }
 
     impl Module for Model {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
             use candle_nn::RNN;
+
+            // This is different from the Python transformers version as candle LSTM is batch first.
+            let xs = xs.t()?;
             let mut xs = xs.clone();
-            for lstm in self.lstms.iter() {
-                let res = lstm.seq(&xs)?;
-                let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
-                xs = Tensor::stack(&res, 1)?;
+            for layer in self.lstms.iter() {
+                let states = layer.seq(&xs)?;
+                xs = layer.states_to_tensor(&states)?;
             }
+            let xs = xs.t()?;
             let embeds_raw = xs.apply(&self.linear)?.relu()?;
-            // TODO: normalize.
-            Ok(embeds_raw)
+            let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
+            embeds_raw.broadcast_div(&norm)
         }
     }
 }
diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
index eb795f18..35f9f3df 100644
--- a/candle-transformers/src/models/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
     mel
 }
 
-fn log_mel_spectrogram_<T: Float>(
+pub fn log_mel_spectrogram_<T: Float>(
     samples: &[T],
     filters: &[T],
    fft_size: usize,
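
Note (not part of the commit): the partial-slice arithmetic in
compute_partial_slices can be exercised standalone. The sketch below mirrors
that logic with the Config::cfg() defaults; the rate of 1.3 partials per
second, the 0.75 minimum coverage, and the three-second input are assumed
values for illustration (the resemblyzer defaults), not taken from the patch.

    // Standalone sketch of the partial-slice arithmetic, using the
    // `Config::cfg()` defaults. `rate` and `min_coverage` are assumed values.
    fn main() {
        let (sampling_rate, partial_n_frames, mel_window_step) = (16_000usize, 160usize, 10usize);
        let (rate, min_coverage) = (1.3f64, 0.75f64);
        let n_samples = 3 * sampling_rate; // pretend input: three seconds of audio

        // One mel frame every 10ms at 16kHz -> 160 samples per frame.
        let samples_per_frame = sampling_rate * mel_window_step / 1000;
        let n_frames = n_samples / samples_per_frame + 1;
        let frame_step = (sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
        let steps = (n_frames + frame_step).saturating_sub(partial_n_frames) + 1;

        let mut wav_slices = vec![];
        let mut mel_slices = vec![];
        for i in (0..steps).step_by(frame_step) {
            mel_slices.push((i, i + partial_n_frames));
            wav_slices.push((i * samples_per_frame, (i + partial_n_frames) * samples_per_frame));
        }
        // Drop the last partial when too little of it covers real audio.
        if let Some(&(start, end)) = wav_slices.last() {
            let coverage = (n_samples - start) as f64 / (end - start) as f64;
            if coverage < min_coverage && mel_slices.len() > 1 {
                mel_slices.pop();
                wav_slices.pop();
            }
        }
        println!("mel slices: {mel_slices:?}");
        println!("wav slices: {wav_slices:?}");
    }

For three seconds of 16kHz audio this yields three overlapping 1.6s partials
(mel slices (0, 160), (77, 237), (154, 314)) stepped roughly 0.77s apart; the
last one covers about 91% real audio, so it is kept.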
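The trailing normalizations in embed_utterance and forward are both plain L2
normalizations: divide by the square root of the sum of squares, over all
elements for the utterance embedding and per row for the partial embeddings.
A minimal sketch of the row-wise variant, assuming the candle-core crate is in
scope as `candle`, as in the file above:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Divide each row by its own L2 norm, as at the end of `forward`.
        let embeds_raw = Tensor::new(&[[3f32, 4.], [0., 2.]], &dev)?;
        let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
        let embeds = embeds_raw.broadcast_div(&norm)?;
        println!("{embeds}"); // rows become [0.6, 0.8] and [0.0, 1.0]
        Ok(())
    }

sum_keepdim(1) keeps the reduced dimension so the (2, 1) norm tensor broadcasts
cleanly against the (2, 2) embeddings, which is why the patch uses it instead
of sum(1).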