mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
More metavoice tweaks. (#1796)
This commit is contained in:
@@ -6,6 +6,7 @@ extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::io::Write;
|
||||
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::encodec;
|
||||
@@ -156,7 +157,7 @@ fn main() -> Result<()> {
|
||||
Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
|
||||
};
|
||||
let spk_emb = spk_emb.to_device(&device)?;
|
||||
- let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
||||
+ let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
||||
|
||||
// First stage generation.
|
||||
for index in 0..args.max_tokens {
|
||||
@@ -172,10 +173,13 @@ fn main() -> Result<()> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
print!(".");
|
||||
std::io::stdout().flush()?;
|
||||
if next_token == 2048 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!();
|
||||
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
|
||||
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
|
||||
println!("text ids len: {}", text_ids.len());
|
||||
|
@@ -297,7 +297,7 @@ pub mod gpt {
|
||||
causal: false,
|
||||
target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
|
||||
swiglu_multiple_of: Some(256),
|
||||
- norm_type: NormType::RMSNorm,
|
||||
+ norm_type: NormType::LayerNorm,
|
||||
kv_cache_enabled: false,
|
||||
attn_kernel_type: AttnKernelType::TorchAttn,
|
||||
spk_emb_on_text: true,
|
||||
|
Reference in New Issue
Block a user