mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Speaker embeddings computation for metavoice. (#1800)
* Speaker embeddings computation for metavoice. * Compute the speaker embeddings.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
|
||||
use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
@ -13,22 +13,41 @@ pub mod speaker_encoder {
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub mel_window_step: usize,
|
||||
pub mel_n_channels: usize,
|
||||
pub sampling_rate: usize,
|
||||
pub partial_n_frames: usize,
|
||||
pub model_hidden_size: usize,
|
||||
pub model_embedding_size: usize,
|
||||
pub model_num_layers: usize,
|
||||
pub mel_window_length: usize,
|
||||
pub mel_window_step: usize,
|
||||
pub mel_n_channels: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn cfg() -> Self {
|
||||
Self {
|
||||
sampling_rate: 16_000,
|
||||
partial_n_frames: 160,
|
||||
model_hidden_size: 256,
|
||||
model_embedding_size: 256,
|
||||
model_num_layers: 3,
|
||||
mel_window_length: 25,
|
||||
mel_window_step: 10,
|
||||
mel_n_channels: 40,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Model {
|
||||
lstms: Vec<candle_nn::LSTM>,
|
||||
linear: Linear,
|
||||
cfg: Config,
|
||||
}
|
||||
|
||||
type Slice = (usize, usize);
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
|
||||
let mut lstms = Vec::with_capacity(cfg.model_num_layers);
|
||||
let vb_l = vb.pp("lstm");
|
||||
for layer_idx in 0..cfg.model_num_layers {
|
||||
@ -50,36 +69,103 @@ pub mod speaker_encoder {
|
||||
true,
|
||||
vb.pp("linear"),
|
||||
)?;
|
||||
Ok(Self { lstms, linear })
|
||||
Ok(Self { lstms, linear, cfg })
|
||||
}
|
||||
|
||||
fn compute_partial_slices(
|
||||
_n_samples: usize,
|
||||
_rate: f64,
|
||||
_min_coverage: f64,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
todo!()
|
||||
&self,
|
||||
n_samples: usize,
|
||||
rate: f64,
|
||||
min_coverage: f64,
|
||||
) -> (Vec<Slice>, Vec<Slice>) {
|
||||
let c = &self.cfg;
|
||||
// Compute how many frames separate two partial utterances
|
||||
let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
|
||||
let n_frames = n_samples / samples_per_frame + 1;
|
||||
let frame_step =
|
||||
(c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
|
||||
let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
|
||||
// Compute the slices.
|
||||
let mut wav_slices = vec![];
|
||||
let mut mel_slices = vec![];
|
||||
for i in (0..steps).step_by(frame_step) {
|
||||
let mel_range = (i, i + c.partial_n_frames);
|
||||
let wav_range = (
|
||||
i * samples_per_frame,
|
||||
(i + c.partial_n_frames) * samples_per_frame,
|
||||
);
|
||||
mel_slices.push(mel_range);
|
||||
wav_slices.push(wav_range);
|
||||
}
|
||||
// Evaluate whether extra padding is warranted or not.
|
||||
let last_wav_range = match wav_slices.last() {
|
||||
None => return (wav_slices, mel_slices),
|
||||
Some(l) => *l,
|
||||
};
|
||||
let coverage = (n_samples - last_wav_range.0) as f64
|
||||
/ (last_wav_range.1 - last_wav_range.0) as f64;
|
||||
if coverage > min_coverage && mel_slices.len() > 1 {
|
||||
mel_slices.pop();
|
||||
wav_slices.pop();
|
||||
}
|
||||
(wav_slices, mel_slices)
|
||||
}
|
||||
|
||||
pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
|
||||
let (_wav_slices, _mel_slices) =
|
||||
Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
|
||||
todo!()
|
||||
pub fn embed_utterance(
|
||||
&self,
|
||||
wav: &[f32],
|
||||
mel_filters: &[f32],
|
||||
rate: f64,
|
||||
min_c: f64,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
|
||||
let max_wave_length = match wav_slices.last() {
|
||||
Some(v) => v.1,
|
||||
None => candle::bail!("empty wav slices"),
|
||||
};
|
||||
let wav = if max_wave_length > wav.len() {
|
||||
let mut wav = wav.to_vec();
|
||||
wav.resize(max_wave_length - wav.len(), 0.0);
|
||||
std::borrow::Cow::Owned(wav)
|
||||
} else {
|
||||
std::borrow::Cow::Borrowed(wav)
|
||||
};
|
||||
let mel = crate::models::whisper::audio::log_mel_spectrogram_(
|
||||
wav.as_ref(),
|
||||
mel_filters,
|
||||
/* fft_size */ self.cfg.mel_window_length,
|
||||
/* fft_step */ self.cfg.mel_window_step,
|
||||
self.cfg.mel_n_channels,
|
||||
false,
|
||||
);
|
||||
let mels = mel_slices
|
||||
.iter()
|
||||
.flat_map(|s| [mel[s.0], mel[s.1]])
|
||||
.collect::<Vec<_>>();
|
||||
let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
|
||||
let partial_embeds = self.forward(&mels)?;
|
||||
let raw_embed = partial_embeds.mean(0)?;
|
||||
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
|
||||
raw_embed.broadcast_div(&norm)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Model {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
|
||||
// This is different from the Python transformers version as candle LSTM is batch first.
|
||||
let xs = xs.t()?;
|
||||
let mut xs = xs.clone();
|
||||
for lstm in self.lstms.iter() {
|
||||
let res = lstm.seq(&xs)?;
|
||||
let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
|
||||
xs = Tensor::stack(&res, 1)?;
|
||||
for layer in self.lstms.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
let xs = xs.t()?;
|
||||
let embeds_raw = xs.apply(&self.linear)?.relu()?;
|
||||
// TODO: normalize.
|
||||
Ok(embeds_raw)
|
||||
let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
|
||||
embeds_raw.broadcast_div(&norm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
mel
|
||||
}
|
||||
|
||||
fn log_mel_spectrogram_<T: Float>(
|
||||
pub fn log_mel_spectrogram_<T: Float>(
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
fft_size: usize,
|
||||
|
Reference in New Issue
Block a user