mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add the DAC model. (#2433)
* Add the DAC model. * More quantization support. * Handle DAC decoding. * Plug the DAC decoding in parler-tts.
This commit is contained in:
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::parler_tts::{Config, Model};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -36,7 +36,7 @@ struct Args {
|
||||
description: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
#[arg(long, default_value_t = 0.0)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
@ -82,6 +82,10 @@ struct Args {
|
||||
|
||||
#[arg(long, default_value_t = 512)]
|
||||
max_steps: usize,
|
||||
|
||||
/// The output wav file.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
@ -152,24 +156,32 @@ fn main() -> anyhow::Result<()> {
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
|
||||
println!("{description_tokens}");
|
||||
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
|
||||
println!("{prompt_tokens}");
|
||||
|
||||
let lp = candle_transformers::generation::LogitsProcessor::new(
|
||||
args.seed,
|
||||
Some(args.temperature),
|
||||
args.top_p,
|
||||
);
|
||||
println!("starting generation...");
|
||||
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
|
||||
println!("{codes}");
|
||||
println!("generated codes\n{codes}");
|
||||
let codes = codes.to_dtype(DType::I64)?;
|
||||
codes.save_safetensors("codes", "out.safetensors")?;
|
||||
let codes = codes.unsqueeze(0)?;
|
||||
let pcm = model
|
||||
.audio_encoder
|
||||
.decode_codes(&codes.to_device(&device)?)?;
|
||||
println!("{pcm}");
|
||||
let pcm = pcm.i((0, 0))?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user