More metavoice tweaks. (#1796)

This commit is contained in:
Laurent Mazare
2024-03-03 15:05:25 +01:00
committed by GitHub
parent 20abb72fec
commit 60dc72b96b
2 changed files with 6 additions and 2 deletions

View File

@ -6,6 +6,7 @@ extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
@ -156,7 +157,7 @@ fn main() -> Result<()> {
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
};
let spk_emb = spk_emb.to_device(&device)?;
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
// First stage generation.
for index in 0..args.max_tokens {
@ -172,10 +173,13 @@ fn main() -> Result<()> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
print!(".");
std::io::stdout().flush()?;
if next_token == 2048 {
break;
}
}
println!();
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
println!("text ids len: {}", text_ids.len());

View File

@ -297,7 +297,7 @@ pub mod gpt {
causal: false,
target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
swiglu_multiple_of: Some(256),
norm_type: NormType::RMSNorm,
norm_type: NormType::LayerNorm,
kv_cache_enabled: false,
attn_kernel_type: AttnKernelType::TorchAttn,
spk_emb_on_text: true,