mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Metavoice position fix (#1791)
* Add the metavoice transformer.
* Sketch the speaker-encoder module.
* Adding to the metavoice model.
* Start adding the metavoice example.
* Get some logits out.
* Load the second stage model.
* Get the second step to run.
* Tweak the example.
* Add encodec tilting.
* Glue the different bits together.
* Fix a shape issue.
* Use a constant.
* BPE tokenization.
* Fix the position index in metavoice.
This commit is contained in:
@ -44,6 +44,10 @@ struct Args {
|
|||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
|
/// The maximum number of tokens to generate for the first stage.
|
||||||
|
#[arg(long, default_value_t = 2000)]
|
||||||
|
max_tokens: u64,
|
||||||
|
|
||||||
/// The output file using the wav format.
|
/// The output file using the wav format.
|
||||||
#[arg(long, default_value = "out.wav")]
|
#[arg(long, default_value = "out.wav")]
|
||||||
out_file: String,
|
out_file: String,
|
||||||
@ -148,19 +152,18 @@ fn main() -> Result<()> {
|
|||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
|
||||||
|
|
||||||
// First stage generation.
|
// First stage generation.
|
||||||
for index in 0.. {
|
for index in 0..args.max_tokens {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
let ctxt = &tokens[start_pos..];
|
let ctxt = &tokens[start_pos..];
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
let input = Tensor::new(ctxt, &device)?;
|
||||||
let input = Tensor::stack(&[&input, &input], 0)?;
|
let input = Tensor::stack(&[&input, &input], 0)?;
|
||||||
let logits = first_stage_model.forward(&input, &spk_emb, index)?;
|
let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
|
||||||
let logits0 = logits.i((0, 0))?;
|
let logits0 = logits.i((0, 0))?;
|
||||||
let logits1 = logits.i((1, 0))?;
|
let logits1 = logits.i((1, 0))?;
|
||||||
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
let logits = logits.to_dtype(DType::F32)?;
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
println!("{} {next_token}", tokens.len());
|
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
if next_token == 2048 {
|
if next_token == 2048 {
|
||||||
break;
|
break;
|
||||||
@ -183,9 +186,9 @@ fn main() -> Result<()> {
|
|||||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||||
let logits = second_stage_model.forward(&in_x)?;
|
let logits = second_stage_model.forward(&in_x)?;
|
||||||
|
println!("sampling from logits...");
|
||||||
let mut codes = vec![];
|
let mut codes = vec![];
|
||||||
for (idx, logits) in logits.iter().enumerate() {
|
for logits in logits.iter() {
|
||||||
println!("{idx} {logits}");
|
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let (seq_len, _) = logits.dims2()?;
|
let (seq_len, _) = logits.dims2()?;
|
||||||
let mut codes_ = Vec::with_capacity(seq_len);
|
let mut codes_ = Vec::with_capacity(seq_len);
|
||||||
|
Reference in New Issue
Block a user