mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
More metavoice tweaks. (#1796)
@@ -6,6 +6,7 @@ extern crate accelerate_src;
 
 use anyhow::Result;
 use clap::Parser;
+use std::io::Write;
 
 use candle_transformers::generation::LogitsProcessor;
 use candle_transformers::models::encodec;
@@ -156,7 +157,7 @@ fn main() -> Result<()> {
         Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
     };
     let spk_emb = spk_emb.to_device(&device)?;
-    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
+    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
 
     // First stage generation.
     for index in 0..args.max_tokens {
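Note: the one functional change in this hunk is the third argument of LogitsProcessor::new, which goes from None (temperature-only sampling) to Some(0.95), i.e. nucleus / top-p sampling with a 0.95 threshold. The plain-Rust sketch below is only an illustration of the idea, not candle code: it keeps the highest-probability tokens until their cumulative mass reaches the threshold and renormalizes over that set (implementations differ slightly on how the token that crosses the threshold is handled).

// Toy nucleus (top-p) filter: keep the most-probable tokens until their
// cumulative probability reaches `top_p`, then renormalize over that set.
fn nucleus(probs: &[f32], top_p: f32) -> Vec<(usize, f32)> {
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    let mut kept = Vec::new();
    let mut cum = 0.0f32;
    for i in idx {
        kept.push((i, probs[i]));
        cum += probs[i];
        if cum >= top_p {
            break;
        }
    }
    let total: f32 = kept.iter().map(|(_, p)| p).sum();
    kept.into_iter().map(|(i, p)| (i, p / total)).collect()
}

fn main() {
    // Hypothetical next-token distribution over five tokens.
    let probs = [0.50f32, 0.30, 0.15, 0.04, 0.01];
    // With top_p = 0.95, the two least likely tokens (0.04 and 0.01) are dropped
    // and the remaining three are renormalized before sampling.
    println!("{:?}", nucleus(&probs, 0.95));
}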
@@ -172,10 +173,13 @@ fn main() -> Result<()> {
         let logits = logits.to_dtype(DType::F32)?;
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
+        print!(".");
+        std::io::stdout().flush()?;
         if next_token == 2048 {
             break;
         }
     }
+    println!();
     let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
     let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
     println!("text ids len: {}", text_ids.len());
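The added lines here only affect progress reporting: print! does not flush stdout on its own, so without the explicit flush the dots would appear in bursts rather than one per generated token, and token id 2048 is treated as the end-of-generation marker for the first stage. A self-contained sketch of the same loop shape follows; the fake_logits helper is a made-up stand-in for the first-stage transformer (which produces the real logits in the example), and the candle-core crate is assumed to be aliased as candle, as in the candle examples.

use std::io::Write;

use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

// Made-up logits source: puts almost all probability mass on token `step`
// for the first few steps, then on the stop token 2048.
fn fake_logits(step: usize) -> candle::Result<Tensor> {
    let mut v = vec![0f32; 2049];
    let hot = if step < 5 { step } else { 2048 };
    v[hot] = 50.0;
    Tensor::new(v.as_slice(), &Device::Cpu)
}

fn main() -> anyhow::Result<()> {
    // Same sampler configuration as in the diff: temperature plus top-p 0.95.
    let mut logits_processor = LogitsProcessor::new(42, Some(1.0), Some(0.95));
    let mut tokens: Vec<u32> = vec![];
    for step in 0..100 {
        let logits = fake_logits(step)?;
        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        // print! does not flush by itself; flush so the dot shows up immediately.
        print!(".");
        std::io::stdout().flush()?;
        if next_token == 2048 {
            break;
        }
    }
    println!();
    println!("generated {} tokens", tokens.len());
    Ok(())
}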
@@ -297,7 +297,7 @@ pub mod gpt {
             causal: false,
             target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
             swiglu_multiple_of: Some(256),
-            norm_type: NormType::RMSNorm,
+            norm_type: NormType::LayerNorm,
             kv_cache_enabled: false,
             attn_kernel_type: AttnKernelType::TorchAttn,
             spk_emb_on_text: true,
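This hunk only flips the normalization selected by this GPT config from NormType::RMSNorm to NormType::LayerNorm. As a rough, plain-Rust reminder of the difference (learned scale and shift parameters omitted): layer norm centers the activations before scaling, while RMS norm skips the mean subtraction and divides by the root mean square only.

// Minimal sketch of the two normalizations, without learnable weight/bias.
fn layer_norm(xs: &[f32], eps: f32) -> Vec<f32> {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let var = xs.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
    xs.iter().map(|x| (x - mean) / (var + eps).sqrt()).collect()
}

fn rms_norm(xs: &[f32], eps: f32) -> Vec<f32> {
    let n = xs.len() as f32;
    let rms = (xs.iter().map(|x| x * x).sum::<f32>() / n + eps).sqrt();
    xs.iter().map(|x| x / rms).collect()
}

fn main() {
    let xs = [1.0f32, 2.0, 3.0, 4.0];
    println!("layer_norm: {:?}", layer_norm(&xs, 1e-5));
    println!("rms_norm:   {:?}", rms_norm(&xs, 1e-5));
}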