diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 2381c594..cfc8c473 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -55,8 +55,12 @@ pub mod speaker_encoder { layer_idx, ..Default::default() }; - let lstm = - candle_nn::lstm(cfg.mel_n_channels, cfg.model_hidden_size, c, vb_l.clone())?; + let in_c = if layer_idx == 0 { + cfg.mel_n_channels + } else { + cfg.model_hidden_size + }; + let lstm = candle_nn::lstm(in_c, cfg.model_hidden_size, c, vb_l.clone())?; lstms.push(lstm) } let linear = linear_b(