From 60dc72b96ba6968c7c3af197a93115409126be50 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 3 Mar 2024 15:05:25 +0100 Subject: [PATCH] More metavoice tweaks. (#1796) --- candle-examples/examples/metavoice/main.rs | 6 +++++- candle-transformers/src/models/metavoice.rs | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index ae751b39..ef6c8079 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -6,6 +6,7 @@ extern crate accelerate_src; use anyhow::Result; use clap::Parser; +use std::io::Write; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::encodec; @@ -156,7 +157,7 @@ fn main() -> Result<()> { Some(spk_emb) => spk_emb.to_dtype(DType::F32)?, }; let spk_emb = spk_emb.to_device(&device)?; - let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None); + let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95)); // First stage generation. for index in 0..args.max_tokens { @@ -172,10 +173,13 @@ fn main() -> Result<()> { let logits = logits.to_dtype(DType::F32)?; let next_token = logits_processor.sample(&logits)?; tokens.push(next_token); + print!("."); + std::io::stdout().flush()?; if next_token == 2048 { break; } } + println!(); let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS); let (text_ids, ids1, ids2) = fie2c.decode(&tokens); println!("text ids len: {}", text_ids.len()); diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index e37c168c..0ab19041 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -297,7 +297,7 @@ pub mod gpt { causal: false, target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025], swiglu_multiple_of: Some(256), - norm_type: NormType::RMSNorm, + norm_type: NormType::LayerNorm, kv_cache_enabled: false, attn_kernel_type: AttnKernelType::TorchAttn, spk_emb_on_text: true,