From dae32d13d6099eaf011cdf665bd222fb89a30e8c Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 5 Mar 2024 22:19:30 +0100 Subject: [PATCH] Speaker embeddings for metavoice. --- candle-examples/Cargo.toml | 4 +- candle-examples/examples/metavoice/main.rs | 112 ++++++++++++++++-- .../examples/metavoice/melfilters40.bytes | Bin 0 -> 32160 bytes 3 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 candle-examples/examples/metavoice/melfilters40.bytes diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index cb704f0c..d4f01804 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -100,4 +100,6 @@ required-features = ["candle-datasets"] name = "encodec" required-features = ["symphonia"] - +[[example]] +name = "metavoice" +required-features = ["symphonia"] diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index ae571929..923ed8bc 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -10,7 +10,9 @@ use std::io::Write; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::encodec; -use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer}; +use candle_transformers::models::metavoice::{ + adapters, gpt, speaker_encoder, tokenizers, transformer, +}; use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; @@ -19,6 +21,60 @@ use rand::{distributions::Distribution, SeedableRng}; pub const ENCODEC_NTOKENS: u32 = 1024; +fn conv(samples: &mut Vec, data: std::borrow::Cow>) +where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample, +{ + use symphonia::core::audio::Signal; + use symphonia::core::conv::FromSample; + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) +} + +fn pcm_decode>(path: P) -> anyhow::Result<(Vec, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + let src = std::fs::File::open(path)?; + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + let hint = symphonia::core::probe::Hint::new(); + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .expect("no supported audio tracks"); + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &Default::default()) + .expect("unsupported codec"); + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + while let Ok(packet) = format.next_packet() { + while !format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet)? { + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] enum ArgDType { F32, @@ -72,6 +128,8 @@ struct Args { #[arg(long)] encodec_weights: Option, + /// The speaker embeddings, either an audio files in which case they are extracted, or a + /// safetensors file with the embeddings already extracted. #[arg(long)] spk_emb: Option, @@ -79,6 +137,13 @@ struct Args { dtype: ArgDType, } +fn mel_filters() -> Result> { + let mel_bytes = include_bytes!("melfilters40.bytes").as_slice(); + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + ::read_f32_into(mel_bytes, &mut mel_filters); + Ok(mel_filters) +} + fn main() -> Result<()> { use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; @@ -159,16 +224,43 @@ fn main() -> Result<()> { let prompt_tokens = fs_tokenizer.encode(&args.prompt)?; let mut tokens = prompt_tokens.clone(); println!("{tokens:?}"); - let spk_emb_file = match &args.spk_emb { - Some(w) => std::path::PathBuf::from(w), - None => repo.get("spk_emb.safetensors")?, + let safetensors_embeddings = args + .spk_emb + .as_ref() + .map_or(true, |v| v.ends_with("safetensors")); + let spk_emb = if safetensors_embeddings { + let spk_emb_file = match &args.spk_emb { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("spk_emb.safetensors")?, + }; + let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?; + match spk_emb.get("spk_emb") { + None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"), + Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?, + } + } else { + let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?; + if sample_rate != 16_000 { + eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!") + } + let mel_filters = mel_filters()?; + let config = speaker_encoder::Config::cfg(); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors( + &["spk_emb_model.safetensors".to_string()], + dtype, + &device, + )? + }; + let model = speaker_encoder::Model::new(config, vb)?; + model.embed_utterance( + &pcm, + &mel_filters, + /* rate */ 1.3, + /* min_c */ 0.75, + &device, + )? }; - let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?; - let spk_emb = match spk_emb.get("spk_emb") { - None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"), - Some(spk_emb) => spk_emb.to_dtype(dtype)?, - }; - let spk_emb = spk_emb.to_device(&device)?; let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95)); // First stage generation. diff --git a/candle-examples/examples/metavoice/melfilters40.bytes b/candle-examples/examples/metavoice/melfilters40.bytes new file mode 100644 index 0000000000000000000000000000000000000000..7b2a06b96f1560d15248bc3eff0ba2a19192d941 GIT binary patch literal 32160 zcmeI)`%jd276$NvASeRMwTyx=HlP(MbfhWsKIauF)?Rdnn}W0w>e_ng(k)mPEmRbA za4h^ptd5gQ-5e$W&VQqkdVxKn2?-L zo|Er+-weoPvc%8k-J=k3nJ5n^CK-RR01HeiFjA9`_336}?OAWJ_Dnp_%A`tMWC59M zq{e}a4HddG#YGggFY>H#C|E#`KvebzRQ!pT82_!MxMS97=y@BLSwQjvx*d!b552D2 zzaEiTlqN_XgaKGUkASkMm9DLQM~u^S^8TgAh|4VSi$GMime!>O=C|Jy`KgqMgkzfHm0&PhOT&*9bT?r(5m6!6Y=yBpQ3nUvCWByO^|N!QJZ-Sm^XV9vqJ(j|FX{h#91bqC7fPV$SD_9xpCSTi}P#P)v0U zAorjenqPK{I<^QV#?eAY-iA>94b0*~eHII!Y6R%d}!LK6A9mr$BEjShPTk$;8-b$9H+ z-Y{GK-xg29@wKD{%47b7Q;P@D@uL@gUA~dRPwpZAcXyC+f|{bqmTHHO!@AFy_cTeT zgqwd_z_H~ObnQHgtS);}-uxp)jwRBOq-2WjSVf5iZshd+J@{$Z zm_67nSX3t>ci?+iWjK(=^)1S8+C&3|v1Adxj0F4HbfNe*93Ap7@I3z=reuS{jXxux zEj=g%nMcF^y>mG8t|>jrRMLWmFberU$5so@(1=RGZZagDoZ4^Pi9KVx=_~SE|Z%k;!;3h6dHse@VE?l;)L4U9f7OPWvj(#V~ z7g7^2EFV%OR9_NiHN1pRo#XMBmM`EKa}F0>Zeigq6S8}4COtedmzEuVjtVQMkandO ze_wGP2r5CpdI#jY{czcFOt|CpXPzagg@^n7RiGxuOf{`MTe#G34Een!P$q6is7Vpx zvszGK*M~L!G8(b6q$#-$!Q^ov*Kk71Gfgco3EATM)nX3f|oP5Yt1a(cEL^ zBzH9+D-SJ(t=iFbzZNHY_9L(HZ+Pgp5~oaE@woW9u=C|4VVzJY2>f+o zyDH?^r@|-0-w9tDJ0dM-DWVl|sLdvW2_Dy)396+ilF5bOJ2*oESkqyUj& g*|5`W)%`3&S6`&7IqR;+1>JiDU0wG*{$#fP58-BtegFUf literal 0 HcmV?d00001