mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
A few more tweaks.
This commit is contained in:
@ -188,7 +188,7 @@ fn main() -> Result<()> {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("first_stage.safetensors")?,
|
||||
};
|
||||
let second_stage_weights = match &args.first_stage_weights {
|
||||
let second_stage_weights = match &args.second_stage_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("second_stage.safetensors")?,
|
||||
};
|
||||
|
@ -55,12 +55,8 @@ pub mod speaker_encoder {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(
|
||||
cfg.mel_n_channels,
|
||||
cfg.model_hidden_size,
|
||||
c,
|
||||
vb_l.pp(layer_idx),
|
||||
)?;
|
||||
let lstm =
|
||||
candle_nn::lstm(cfg.mel_n_channels, cfg.model_hidden_size, c, vb_l.clone())?;
|
||||
lstms.push(lstm)
|
||||
}
|
||||
let linear = linear_b(
|
||||
|
Reference in New Issue
Block a user