Add an argument for the speaker encoder weights.

This commit is contained in:
laurent
2024-03-05 22:33:41 +01:00
parent dae32d13d6
commit 52ed77c16f

View File

@@ -125,6 +125,9 @@ struct Args {
#[arg(long)]
second_stage_weights: Option<String>,
#[arg(long)]
speaker_encoder_weights: Option<String>,
#[arg(long)]
encodec_weights: Option<String>,
@@ -239,20 +242,18 @@ fn main() -> Result<()> {
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
}
} else {
let weights = match &args.speaker_encoder_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("speaker_encoder.safetensors")?,
};
let mel_filters = mel_filters()?;
let config = speaker_encoder::Config::cfg();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
let model = speaker_encoder::Model::new(config, vb)?;
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
if sample_rate != 16_000 {
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
}
let mel_filters = mel_filters()?;
let config = speaker_encoder::Config::cfg();
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
&["spk_emb_model.safetensors".to_string()],
dtype,
&device,
)?
};
let model = speaker_encoder::Model::new(config, vb)?;
model.embed_utterance(
&pcm,
&mel_filters,