Add the DAC model. (#2433)

* Add the DAC model.

* More quantization support.

* Handle DAC decoding.

* Plug the DAC decoding in parler-tts.
This commit is contained in:
Laurent Mazare
2024-08-19 07:59:51 +01:00
committed by GitHub
parent 58197e1896
commit 236b29ff15
7 changed files with 404 additions and 8 deletions

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Tensor};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;
@ -36,7 +36,7 @@ struct Args {
description: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
#[arg(long, default_value_t = 0.0)]
temperature: f64,
/// Nucleus sampling probability cutoff.
@ -82,6 +82,10 @@ struct Args {
#[arg(long, default_value_t = 512)]
max_steps: usize,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
}
fn main() -> anyhow::Result<()> {
@ -152,24 +156,32 @@ fn main() -> anyhow::Result<()> {
.get_ids()
.to_vec();
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
println!("{description_tokens}");
let prompt_tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
println!("{prompt_tokens}");
let lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
args.top_p,
);
println!("starting generation...");
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
println!("{codes}");
println!("generated codes\n{codes}");
let codes = codes.to_dtype(DType::I64)?;
codes.save_safetensors("codes", "out.safetensors")?;
let codes = codes.unsqueeze(0)?;
let pcm = model
.audio_encoder
.decode_codes(&codes.to_device(&device)?)?;
println!("{pcm}");
let pcm = pcm.i((0, 0))?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
Ok(())
}