mirror of https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00

Add the DAC model. (#2433)

* Add the DAC model.
* More quantization support.
* Handle DAC decoding.
* Plug the DAC decoding in parler-tts.
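In short, the parler-tts example can now decode the generated audio codes into a waveform through the new DAC audio encoder instead of only dumping them to a safetensors file. A condensed sketch of that decode path, based on the main.rs changes in this commit; the helper name codes_to_wav and its signature are illustrative, not part of the commit:

use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::parler_tts;

/// Hypothetical helper condensing the example's new decoding steps:
/// generated codes -> DAC decode -> loudness-normalized PCM -> wav file.
fn codes_to_wav(
    model: &parler_tts::Model,
    codes: &Tensor, // (num_codebooks, seq_len), as returned by `generate`
    sampling_rate: u32,
    out_file: &str,
) -> Result<()> {
    let codes = codes.to_dtype(DType::I64)?.unsqueeze(0)?; // add a batch dimension
    let pcm = model.audio_encoder.decode_codes(&codes)?; // (batch, 1, num_samples)
    let pcm = pcm.i((0, 0))?;
    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
    let pcm = pcm.to_vec1::<f32>()?;
    let mut output = std::fs::File::create(out_file)?;
    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, sampling_rate)?;
    Ok(())
}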
This commit is contained in:
.gitignore (vendored, 1 line changed)

@@ -42,3 +42,4 @@ candle-wasm-examples/**/config*.json
.idea/*
__pycache__
out.safetensors
out.wav
@@ -5,6 +5,7 @@ from parler_tts import DACModel

tensors = load_file("out.safetensors")
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
print(dac_model.model)
output_ids = tensors["codes"][None, None]
print(output_ids, "\n", output_ids.shape)
batch_size = 1
@@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;

@@ -36,7 +36,7 @@ struct Args {
    description: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 1.0)]
    #[arg(long, default_value_t = 0.0)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.

@@ -82,6 +82,10 @@ struct Args {

    #[arg(long, default_value_t = 512)]
    max_steps: usize,

    /// The output wav file.
    #[arg(long, default_value = "out.wav")]
    out_file: String,
}

fn main() -> anyhow::Result<()> {

@@ -152,24 +156,32 @@ fn main() -> anyhow::Result<()> {
        .get_ids()
        .to_vec();
    let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
    println!("{description_tokens}");

    let prompt_tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
    println!("{prompt_tokens}");

    let lp = candle_transformers::generation::LogitsProcessor::new(
        args.seed,
        Some(args.temperature),
        args.top_p,
    );
    println!("starting generation...");
    let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
    println!("{codes}");
    println!("generated codes\n{codes}");
    let codes = codes.to_dtype(DType::I64)?;
    codes.save_safetensors("codes", "out.safetensors")?;
    let codes = codes.unsqueeze(0)?;
    let pcm = model
        .audio_encoder
        .decode_codes(&codes.to_device(&device)?)?;
    println!("{pcm}");
    let pcm = pcm.i((0, 0))?;
    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
    let pcm = pcm.to_vec1::<f32>()?;
    let mut output = std::fs::File::create(&args.out_file)?;
    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;

    Ok(())
}
candle-transformers/src/models/dac.rs (new file, 376 lines)

@@ -0,0 +1,376 @@
/// Adapted from https://github.com/descriptinc/descript-audio-codec
use crate::models::encodec;
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub num_codebooks: usize,
    pub model_bitrate: u32,
    pub codebook_size: usize,
    pub latent_dim: usize,
    pub frame_rate: u32,
    pub sampling_rate: u32,
}

#[derive(Debug, Clone)]
pub struct Snake1d {
    alpha: Tensor,
}

impl Snake1d {
    pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
        let alpha = vb.get((1, channels, 1), "alpha")?;
        Ok(Self { alpha })
    }
}

impl candle::Module for Snake1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs_shape = xs.shape();
        let xs = xs.flatten_from(2)?;
        let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
        let sin = (&sin * &sin)?;
        (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
    }
}

#[derive(Debug, Clone)]
pub struct ResidualUnit {
    snake1: Snake1d,
    conv1: Conv1d,
    snake2: Snake1d,
    conv2: Conv1d,
}

impl ResidualUnit {
    pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
        let pad = ((7 - 1) * dilation) / 2;
        let vb = vb.pp("block");
        let snake1 = Snake1d::new(dim, vb.pp(0))?;
        let cfg1 = Conv1dConfig {
            dilation,
            padding: pad,
            ..Default::default()
        };
        let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
        let snake2 = Snake1d::new(dim, vb.pp(2))?;
        let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
        Ok(Self {
            snake1,
            conv1,
            snake2,
            conv2,
        })
    }
}

impl candle::Module for ResidualUnit {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let ys = xs
            .apply(&self.snake1)?
            .apply(&self.conv1)?
            .apply(&self.snake2)?
            .apply(&self.conv2)?;
        let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
        if pad > 0 {
            &ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
        } else {
            ys + xs
        }
    }
}

#[derive(Debug, Clone)]
pub struct EncoderBlock {
    res1: ResidualUnit,
    res2: ResidualUnit,
    res3: ResidualUnit,
    snake1: Snake1d,
    conv1: Conv1d,
}

impl EncoderBlock {
    pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("block");
        let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?;
        let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?;
        let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?;
        let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
        let cfg1 = Conv1dConfig {
            stride,
            padding: (stride + 1) / 2,
            ..Default::default()
        };
        let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
        Ok(Self {
            res1,
            res2,
            res3,
            snake1,
            conv1,
        })
    }
}

impl candle::Module for EncoderBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.res1)?
            .apply(&self.res2)?
            .apply(&self.res3)?
            .apply(&self.snake1)?
            .apply(&self.conv1)
    }
}

#[derive(Debug, Clone)]
pub struct Encoder {
    conv1: Conv1d,
    blocks: Vec<EncoderBlock>,
    snake1: Snake1d,
    conv2: Conv1d,
}

impl candle::Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.conv1)?;
        for block in self.blocks.iter() {
            xs = xs.apply(block)?
        }
        xs.apply(&self.snake1)?.apply(&self.conv2)
    }
}

impl Encoder {
    pub fn new(
        mut d_model: usize,
        strides: &[usize],
        d_latent: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("block");
        let cfg1 = Conv1dConfig {
            padding: 3,
            ..Default::default()
        };
        let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?;
        let mut blocks = Vec::with_capacity(strides.len());
        for (block_idx, stride) in strides.iter().enumerate() {
            d_model *= 2;
            let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?;
            blocks.push(block)
        }
        let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?;
        let cfg2 = Conv1dConfig {
            padding: 1,
            ..Default::default()
        };
        let conv2 =
            encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?;
        Ok(Self {
            conv1,
            blocks,
            snake1,
            conv2,
        })
    }
}

#[derive(Debug, Clone)]
pub struct DecoderBlock {
    snake1: Snake1d,
    conv_tr1: ConvTranspose1d,
    res1: ResidualUnit,
    res2: ResidualUnit,
    res3: ResidualUnit,
}

impl DecoderBlock {
    pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("block");
        let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
        let cfg = ConvTranspose1dConfig {
            stride,
            padding: (stride + 1) / 2,
            ..Default::default()
        };
        let conv_tr1 = encodec::conv_transpose1d_weight_norm(
            in_dim,
            out_dim,
            2 * stride,
            true,
            cfg,
            vb.pp(1),
        )?;
        let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?;
        let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?;
        let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?;
        Ok(Self {
            snake1,
            conv_tr1,
            res1,
            res2,
            res3,
        })
    }
}

impl candle_nn::Module for DecoderBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.snake1)?
            .apply(&self.conv_tr1)?
            .apply(&self.res1)?
            .apply(&self.res2)?
            .apply(&self.res3)
    }
}

#[derive(Debug, Clone)]
pub struct Decoder {
    conv1: Conv1d,
    blocks: Vec<DecoderBlock>,
    snake1: Snake1d,
    conv2: Conv1d,
}

impl Decoder {
    pub fn new(
        in_c: usize,
        mut channels: usize,
        rates: &[usize],
        d_out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("model");
        let cfg1 = Conv1dConfig {
            padding: 3,
            ..Default::default()
        };
        let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?;
        let mut blocks = Vec::with_capacity(rates.len());
        for (idx, stride) in rates.iter().enumerate() {
            let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?;
            channels /= 2;
            blocks.push(block)
        }
        let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?;
        let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?;
        Ok(Self {
            conv1,
            blocks,
            snake1,
            conv2,
        })
    }
}

impl candle::Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.conv1)?;
        for block in self.blocks.iter() {
            xs = xs.apply(block)?
        }
        xs.apply(&self.snake1)?.apply(&self.conv2)
    }
}

#[allow(unused)]
#[derive(Clone, Debug)]
pub struct VectorQuantizer {
    in_proj: Conv1d,
    out_proj: Conv1d,
    codebook: candle_nn::Embedding,
}

impl VectorQuantizer {
    pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> {
        let in_proj =
            encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
        let out_proj =
            encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
        let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
        Ok(Self {
            in_proj,
            out_proj,
            codebook,
        })
    }

    pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
        embed_id.apply(&self.codebook)
    }

    pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
        self.embed_code(embed_id)?.transpose(1, 2)
    }
}

#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    quantizers: Vec<VectorQuantizer>,
}

impl ResidualVectorQuantizer {
    pub fn new(
        input_dim: usize,
        n_codebooks: usize,
        cb_size: usize,
        cb_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = &vb.pp("quantizers");
        let quantizers = (0..n_codebooks)
            .map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i)))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { quantizers })
    }

    pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
        let mut sum = None;
        for (idx, quantizer) in self.quantizers.iter().enumerate() {
            let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?;
            let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
            let s = match sum {
                None => z_q_i,
                Some(s) => (s + z_q_i)?,
            };
            sum = Some(s)
        }
        match sum {
            Some(s) => Ok(s),
            None => candle::bail!("empty codebooks"),
        }
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    pub encoder: Encoder,
    pub quantizer: ResidualVectorQuantizer,
    pub decoder: Decoder,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("model");
        let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?;
        let quantizer = ResidualVectorQuantizer::new(
            cfg.latent_dim,
            cfg.num_codebooks,
            cfg.codebook_size,
            8,
            vb.pp("quantizer"),
        )?;
        let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?;
        Ok(Self {
            encoder,
            decoder,
            quantizer,
        })
    }

    pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> {
        let audio_values = self.quantizer.from_codes(audio_codes)?;
        audio_values.apply(&self.decoder)
    }
}
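As an aside, the dac::Model defined above can also be driven on its own: build a Config, load the weights into a VarBuilder, and call decode_codes on a (batch, num_codebooks, seq_len) tensor of codes. A minimal, hypothetical sketch; the config values and the "model.safetensors" path are placeholders roughly matching the 44kHz/8kbps DAC variant used by parler-tts, not values taken from this commit:

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::dac;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder configuration; in practice this would be deserialized from the
    // model's config file rather than written out by hand.
    let cfg = dac::Config {
        num_codebooks: 9,
        model_bitrate: 8,
        codebook_size: 1024,
        latent_dim: 1024,
        frame_rate: 86,
        sampling_rate: 44100,
    };
    // "model.safetensors" is a placeholder path for the DAC weights.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = dac::Model::new(&cfg, vb)?;
    // Codes are expected as (batch, num_codebooks, seq_len); zeros are used here
    // purely to illustrate the shape.
    let codes = Tensor::zeros((1, 9, 100), DType::U32, &device)?;
    let pcm = model.decode_codes(&codes)?; // (batch, 1, num_samples)
    println!("{:?}", pcm.shape());
    Ok(())
}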
@@ -136,7 +136,7 @@ pub fn conv1d_weight_norm(
    Ok(Conv1d::new(weight, Some(bias), config))
}

fn conv_transpose1d_weight_norm(
pub fn conv_transpose1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
@@ -9,6 +9,7 @@ pub mod clip;
pub mod codegeex4_9b;
pub mod convmixer;
pub mod convnext;
pub mod dac;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod dinov2reg4;
@@ -31,6 +31,7 @@ pub struct Config {
    pub decoder: DecoderConfig,
    pub text_encoder: t5::Config,
    pub vocab_size: usize,
    pub audio_encoder: crate::models::dac::Config,
}

#[derive(Debug, Clone)]

@@ -325,6 +326,7 @@ pub struct Model {
    pub text_encoder: t5::T5EncoderModel,
    pub decoder_start_token_id: u32,
    pub pad_token_id: u32,
    pub audio_encoder: crate::models::dac::Model,
}

impl Model {

@@ -347,6 +349,8 @@ impl Model {
        } else {
            None
        };
        let audio_encoder =
            crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?;
        Ok(Self {
            decoder,
            text_encoder,

@@ -354,6 +358,7 @@ impl Model {
            enc_to_dec_proj,
            decoder_start_token_id: cfg.decoder_start_token_id,
            pad_token_id: cfg.pad_token_id,
            audio_encoder,
        })
    }