From de11623752edbeb42c233256dfc83f56b688e61b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 2 Mar 2024 21:00:35 +0100 Subject: [PATCH] Metavoice position fix (#1791) * Add the metavoice transformer. * Sketch the speaker-encoder module. * Adding to the metavoice model. * Start adding the metavoice example. * Get some logits out. * Load the second stage model. * Get the second step to run. * Tweak the example. * Add encodec tilting. * Glue the different bits together. * Fix a shape issue. * Use a constant. * BPE tokenization. * Fix the position index in metavoice. --- candle-examples/examples/metavoice/main.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index 6788976a..62205495 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -44,6 +44,10 @@ struct Args { #[arg(long, default_value_t = 299792458)] seed: u64, + /// The maximum number of tokens to generate for the first stage. + #[arg(long, default_value_t = 2000)] + max_tokens: u64, + /// The output file using the wav format. #[arg(long, default_value = "out.wav")] out_file: String, @@ -148,19 +152,18 @@ fn main() -> Result<()> { let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None); // First stage generation. - for index in 0.. { + for index in 0..args.max_tokens { let context_size = if index > 0 { 1 } else { tokens.len() }; let start_pos = tokens.len().saturating_sub(context_size); let ctxt = &tokens[start_pos..]; let input = Tensor::new(ctxt, &device)?; let input = Tensor::stack(&[&input, &input], 0)?; - let logits = first_stage_model.forward(&input, &spk_emb, index)?; + let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?; let logits0 = logits.i((0, 0))?; let logits1 = logits.i((1, 0))?; let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?; let logits = logits.to_dtype(DType::F32)?; let next_token = logits_processor.sample(&logits)?; - println!("{} {next_token}", tokens.len()); tokens.push(next_token); if next_token == 2048 { break; @@ -183,9 +186,9 @@ fn main() -> Result<()> { let in_x2 = Tensor::new(hierarchies_in2, &device)?; let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?; let logits = second_stage_model.forward(&in_x)?; + println!("sampling from logits..."); let mut codes = vec![]; - for (idx, logits) in logits.iter().enumerate() { - println!("{idx} {logits}"); + for logits in logits.iter() { let logits = logits.squeeze(0)?; let (seq_len, _) = logits.dims2()?; let mut codes_ = Vec::with_capacity(seq_len);