From 52ed77c16f8e65b07835ec5e71c613e2dbb1167f Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 5 Mar 2024 22:33:41 +0100 Subject: [PATCH] Add an argument for the speaker encoder weights. --- candle-examples/examples/metavoice/main.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index 923ed8bc..72b1d39a 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -125,6 +125,9 @@ struct Args { #[arg(long)] second_stage_weights: Option<String>, + #[arg(long)] + speaker_encoder_weights: Option<String>, + #[arg(long)] encodec_weights: Option<String>, @@ -239,20 +242,18 @@ fn main() -> Result<()> { Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?, } } else { + let weights = match &args.speaker_encoder_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("speaker_encoder.safetensors")?, + }; + let mel_filters = mel_filters()?; + let config = speaker_encoder::Config::cfg(); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? }; + let model = speaker_encoder::Model::new(config, vb)?; let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?; if sample_rate != 16_000 { eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!") } - let mel_filters = mel_filters()?; - let config = speaker_encoder::Config::cfg(); - let vb = unsafe { - VarBuilder::from_mmaped_safetensors( - &["spk_emb_model.safetensors".to_string()], - dtype, - &device, - )? - }; - let model = speaker_encoder::Model::new(config, vb)?; model.embed_utterance( &pcm, &mel_filters,