mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Metavoice position fix (#1791)
* Add the metavoice transformer.
* Sketch the speaker-encoder module.
* Adding to the metavoice model.
* Start adding the metavoice example.
* Get some logits out.
* Load the second stage model.
* Get the second step to run.
* Tweak the example.
* Add encodec tilting.
* Glue the different bits together.
* Fix a shape issue.
* Use a constant.
* BPE tokenization.
* Fix the position index in metavoice.
This commit is contained in:
@ -44,6 +44,10 @@ struct Args {
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The maximum number of tokens to generate for the first stage.
|
||||
#[arg(long, default_value_t = 2000)]
|
||||
max_tokens: u64,
|
||||
|
||||
/// The output file using the wav format.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
@ -148,19 +152,18 @@ fn main() -> Result<()> {
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
||||
|
||||
// First stage generation.
|
||||
for index in 0.. {
|
||||
for index in 0..args.max_tokens {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let input = Tensor::stack(&[&input, &input], 0)?;
|
||||
let logits = first_stage_model.forward(&input, &spk_emb, index)?;
|
||||
let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
|
||||
let logits0 = logits.i((0, 0))?;
|
||||
let logits1 = logits.i((1, 0))?;
|
||||
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
println!("{} {next_token}", tokens.len());
|
||||
tokens.push(next_token);
|
||||
if next_token == 2048 {
|
||||
break;
|
||||
@ -183,9 +186,9 @@ fn main() -> Result<()> {
|
||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||
let logits = second_stage_model.forward(&in_x)?;
|
||||
println!("sampling from logits...");
|
||||
let mut codes = vec![];
|
||||
for (idx, logits) in logits.iter().enumerate() {
|
||||
println!("{idx} {logits}");
|
||||
for logits in logits.iter() {
|
||||
let logits = logits.squeeze(0)?;
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let mut codes_ = Vec::with_capacity(seq_len);
|
||||
|
Reference in New Issue
Block a user