mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Speaker embeddings computation for metavoice. (#1800)
* Speaker embeddings computation for metavoice.
* Compute the speaker embeddings.
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
|
use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
|
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
|
||||||
|
|
||||||
// Equivalent to torch.repeat_interleave
|
// Equivalent to torch.repeat_interleave
|
||||||
@ -13,22 +13,41 @@ pub mod speaker_encoder {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub mel_window_step: usize,
|
|
||||||
pub mel_n_channels: usize,
|
|
||||||
pub sampling_rate: usize,
|
pub sampling_rate: usize,
|
||||||
pub partial_n_frames: usize,
|
pub partial_n_frames: usize,
|
||||||
pub model_hidden_size: usize,
|
pub model_hidden_size: usize,
|
||||||
pub model_embedding_size: usize,
|
pub model_embedding_size: usize,
|
||||||
pub model_num_layers: usize,
|
pub model_num_layers: usize,
|
||||||
|
pub mel_window_length: usize,
|
||||||
|
pub mel_window_step: usize,
|
||||||
|
pub mel_n_channels: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn cfg() -> Self {
|
||||||
|
Self {
|
||||||
|
sampling_rate: 16_000,
|
||||||
|
partial_n_frames: 160,
|
||||||
|
model_hidden_size: 256,
|
||||||
|
model_embedding_size: 256,
|
||||||
|
model_num_layers: 3,
|
||||||
|
mel_window_length: 25,
|
||||||
|
mel_window_step: 10,
|
||||||
|
mel_n_channels: 40,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
lstms: Vec<candle_nn::LSTM>,
|
lstms: Vec<candle_nn::LSTM>,
|
||||||
linear: Linear,
|
linear: Linear,
|
||||||
|
cfg: Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Slice = (usize, usize);
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let mut lstms = Vec::with_capacity(cfg.model_num_layers);
|
let mut lstms = Vec::with_capacity(cfg.model_num_layers);
|
||||||
let vb_l = vb.pp("lstm");
|
let vb_l = vb.pp("lstm");
|
||||||
for layer_idx in 0..cfg.model_num_layers {
|
for layer_idx in 0..cfg.model_num_layers {
|
||||||
@ -50,36 +69,103 @@ pub mod speaker_encoder {
|
|||||||
true,
|
true,
|
||||||
vb.pp("linear"),
|
vb.pp("linear"),
|
||||||
)?;
|
)?;
|
||||||
Ok(Self { lstms, linear })
|
Ok(Self { lstms, linear, cfg })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_partial_slices(
|
fn compute_partial_slices(
|
||||||
_n_samples: usize,
|
&self,
|
||||||
_rate: f64,
|
n_samples: usize,
|
||||||
_min_coverage: f64,
|
rate: f64,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
min_coverage: f64,
|
||||||
todo!()
|
) -> (Vec<Slice>, Vec<Slice>) {
|
||||||
|
let c = &self.cfg;
|
||||||
|
// Compute how many frames separate two partial utterances
|
||||||
|
let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
|
||||||
|
let n_frames = n_samples / samples_per_frame + 1;
|
||||||
|
let frame_step =
|
||||||
|
(c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
|
||||||
|
let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
|
||||||
|
// Compute the slices.
|
||||||
|
let mut wav_slices = vec![];
|
||||||
|
let mut mel_slices = vec![];
|
||||||
|
for i in (0..steps).step_by(frame_step) {
|
||||||
|
let mel_range = (i, i + c.partial_n_frames);
|
||||||
|
let wav_range = (
|
||||||
|
i * samples_per_frame,
|
||||||
|
(i + c.partial_n_frames) * samples_per_frame,
|
||||||
|
);
|
||||||
|
mel_slices.push(mel_range);
|
||||||
|
wav_slices.push(wav_range);
|
||||||
|
}
|
||||||
|
// Evaluate whether extra padding is warranted or not.
|
||||||
|
let last_wav_range = match wav_slices.last() {
|
||||||
|
None => return (wav_slices, mel_slices),
|
||||||
|
Some(l) => *l,
|
||||||
|
};
|
||||||
|
let coverage = (n_samples - last_wav_range.0) as f64
|
||||||
|
/ (last_wav_range.1 - last_wav_range.0) as f64;
|
||||||
|
if coverage > min_coverage && mel_slices.len() > 1 {
|
||||||
|
mel_slices.pop();
|
||||||
|
wav_slices.pop();
|
||||||
|
}
|
||||||
|
(wav_slices, mel_slices)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
|
pub fn embed_utterance(
|
||||||
let (_wav_slices, _mel_slices) =
|
&self,
|
||||||
Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
|
wav: &[f32],
|
||||||
todo!()
|
mel_filters: &[f32],
|
||||||
|
rate: f64,
|
||||||
|
min_c: f64,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
|
||||||
|
let max_wave_length = match wav_slices.last() {
|
||||||
|
Some(v) => v.1,
|
||||||
|
None => candle::bail!("empty wav slices"),
|
||||||
|
};
|
||||||
|
let wav = if max_wave_length > wav.len() {
|
||||||
|
let mut wav = wav.to_vec();
|
||||||
|
wav.resize(max_wave_length - wav.len(), 0.0);
|
||||||
|
std::borrow::Cow::Owned(wav)
|
||||||
|
} else {
|
||||||
|
std::borrow::Cow::Borrowed(wav)
|
||||||
|
};
|
||||||
|
let mel = crate::models::whisper::audio::log_mel_spectrogram_(
|
||||||
|
wav.as_ref(),
|
||||||
|
mel_filters,
|
||||||
|
/* fft_size */ self.cfg.mel_window_length,
|
||||||
|
/* fft_step */ self.cfg.mel_window_step,
|
||||||
|
self.cfg.mel_n_channels,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let mels = mel_slices
|
||||||
|
.iter()
|
||||||
|
.flat_map(|s| [mel[s.0], mel[s.1]])
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
|
||||||
|
let partial_embeds = self.forward(&mels)?;
|
||||||
|
let raw_embed = partial_embeds.mean(0)?;
|
||||||
|
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
|
||||||
|
raw_embed.broadcast_div(&norm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for Model {
|
impl Module for Model {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
use candle_nn::RNN;
|
use candle_nn::RNN;
|
||||||
|
|
||||||
|
// This is different from the Python transformers version as candle LSTM is batch first.
|
||||||
|
let xs = xs.t()?;
|
||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
for lstm in self.lstms.iter() {
|
for layer in self.lstms.iter() {
|
||||||
let res = lstm.seq(&xs)?;
|
let states = layer.seq(&xs)?;
|
||||||
let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
|
xs = layer.states_to_tensor(&states)?;
|
||||||
xs = Tensor::stack(&res, 1)?;
|
|
||||||
}
|
}
|
||||||
|
let xs = xs.t()?;
|
||||||
let embeds_raw = xs.apply(&self.linear)?.relu()?;
|
let embeds_raw = xs.apply(&self.linear)?.relu()?;
|
||||||
// TODO: normalize.
|
let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
|
||||||
Ok(embeds_raw)
|
embeds_raw.broadcast_div(&norm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
|
|||||||
mel
|
mel
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log_mel_spectrogram_<T: Float>(
|
pub fn log_mel_spectrogram_<T: Float>(
|
||||||
samples: &[T],
|
samples: &[T],
|
||||||
filters: &[T],
|
filters: &[T],
|
||||||
fft_size: usize,
|
fft_size: usize,
|
||||||
|
Reference in New Issue
Block a user