mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add an argument for the speaker encoder weights.
This commit is contained in:
@ -125,6 +125,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
second_stage_weights: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
speaker_encoder_weights: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
encodec_weights: Option<String>,
|
||||
|
||||
@ -239,20 +242,18 @@ fn main() -> Result<()> {
|
||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
|
||||
}
|
||||
} else {
|
||||
let weights = match &args.speaker_encoder_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("speaker_encoder.safetensors")?,
|
||||
};
|
||||
let mel_filters = mel_filters()?;
|
||||
let config = speaker_encoder::Config::cfg();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
|
||||
let model = speaker_encoder::Model::new(config, vb)?;
|
||||
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
|
||||
if sample_rate != 16_000 {
|
||||
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
|
||||
}
|
||||
let mel_filters = mel_filters()?;
|
||||
let config = speaker_encoder::Config::cfg();
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(
|
||||
&["spk_emb_model.safetensors".to_string()],
|
||||
dtype,
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
let model = speaker_encoder::Model::new(config, vb)?;
|
||||
model.embed_utterance(
|
||||
&pcm,
|
||||
&mel_filters,
|
||||
|
Reference in New Issue
Block a user