diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index cfc8c473..b3516d21 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -143,7 +143,7 @@ pub mod speaker_encoder { .iter() .flat_map(|s| [mel[s.0], mel[s.1]]) .collect::>(); - let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?; + let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?; let partial_embeds = self.forward(&mels)?; let raw_embed = partial_embeds.mean(0)?; let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;