diff --git a/.gitignore b/.gitignore index 06859efb..38a7d504 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ candle-wasm-examples/**/config*.json .idea/* __pycache__ out.safetensors +out.wav diff --git a/candle-examples/examples/parler-tts/decode.py b/candle-examples/examples/parler-tts/decode.py index b79ebda1..8942d32e 100644 --- a/candle-examples/examples/parler-tts/decode.py +++ b/candle-examples/examples/parler-tts/decode.py @@ -5,6 +5,7 @@ from parler_tts import DACModel tensors = load_file("out.safetensors") dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps") +print(dac_model.model) output_ids = tensors["codes"][None, None] print(output_ids, "\n", output_ids.shape) batch_size = 1 diff --git a/candle-examples/examples/parler-tts/main.rs b/candle-examples/examples/parler-tts/main.rs index e6cf3c44..4e3730e2 100644 --- a/candle-examples/examples/parler-tts/main.rs +++ b/candle-examples/examples/parler-tts/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; use anyhow::Error as E; use clap::Parser; -use candle::{DType, Tensor}; +use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::parler_tts::{Config, Model}; use tokenizers::Tokenizer; @@ -36,7 +36,7 @@ struct Args { description: String, /// The temperature used to generate samples. - #[arg(long, default_value_t = 1.0)] + #[arg(long, default_value_t = 0.0)] temperature: f64, /// Nucleus sampling probability cutoff. @@ -82,6 +82,10 @@ struct Args { #[arg(long, default_value_t = 512)] max_steps: usize, + + /// The output wav file. + #[arg(long, default_value = "out.wav")] + out_file: String, } fn main() -> anyhow::Result<()> { @@ -152,24 +156,32 @@ fn main() -> anyhow::Result<()> { .get_ids() .to_vec(); let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?; - println!("{description_tokens}"); - let prompt_tokens = tokenizer .encode(args.prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?; - println!("{prompt_tokens}"); - let lp = candle_transformers::generation::LogitsProcessor::new( args.seed, Some(args.temperature), args.top_p, ); + println!("starting generation..."); let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?; - println!("{codes}"); + println!("generated codes\n{codes}"); let codes = codes.to_dtype(DType::I64)?; codes.save_safetensors("codes", "out.safetensors")?; + let codes = codes.unsqueeze(0)?; + let pcm = model + .audio_encoder + .decode_codes(&codes.to_device(&device)?)?; + println!("{pcm}"); + let pcm = pcm.i((0, 0))?; + let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?; + let pcm = pcm.to_vec1::()?; + let mut output = std::fs::File::create(&args.out_file)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?; + Ok(()) } diff --git a/candle-transformers/src/models/dac.rs b/candle-transformers/src/models/dac.rs new file mode 100644 index 00000000..fa6c8c71 --- /dev/null +++ b/candle-transformers/src/models/dac.rs @@ -0,0 +1,376 @@ +/// Adapted from https://github.com/descriptinc/descript-audio-codec +use crate::models::encodec; +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder}; + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub num_codebooks: usize, + pub model_bitrate: u32, + pub codebook_size: usize, + pub latent_dim: usize, + pub frame_rate: u32, + pub sampling_rate: u32, +} + +#[derive(Debug, Clone)] +pub struct Snake1d { + alpha: Tensor, +} + +impl Snake1d { + pub fn new(channels: usize, vb: VarBuilder) -> Result { + let alpha = vb.get((1, channels, 1), "alpha")?; + Ok(Self { alpha }) + } +} + +impl candle::Module for Snake1d { + fn forward(&self, xs: &Tensor) -> Result { + let xs_shape = xs.shape(); + let xs = xs.flatten_from(2)?; + let sin = self.alpha.broadcast_mul(&xs)?.sin()?; + let sin = (&sin * &sin)?; + (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape) + } +} + +#[derive(Debug, Clone)] +pub struct ResidualUnit { + snake1: Snake1d, + conv1: Conv1d, + snake2: Snake1d, + conv2: Conv1d, +} + +impl ResidualUnit { + pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result { + let pad = ((7 - 1) * dilation) / 2; + let vb = vb.pp("block"); + let snake1 = Snake1d::new(dim, vb.pp(0))?; + let cfg1 = Conv1dConfig { + dilation, + padding: pad, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?; + let snake2 = Snake1d::new(dim, vb.pp(2))?; + let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?; + Ok(Self { + snake1, + conv1, + snake2, + conv2, + }) + } +} + +impl candle::Module for ResidualUnit { + fn forward(&self, xs: &Tensor) -> Result { + let ys = xs + .apply(&self.snake1)? + .apply(&self.conv1)? + .apply(&self.snake2)? + .apply(&self.conv2)?; + let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2; + if pad > 0 { + &ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?) + } else { + ys + xs + } + } +} + +#[derive(Debug, Clone)] +pub struct EncoderBlock { + res1: ResidualUnit, + res2: ResidualUnit, + res3: ResidualUnit, + snake1: Snake1d, + conv1: Conv1d, +} + +impl EncoderBlock { + pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result { + let vb = vb.pp("block"); + let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?; + let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?; + let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?; + let snake1 = Snake1d::new(dim / 2, vb.pp(3))?; + let cfg1 = Conv1dConfig { + stride, + padding: (stride + 1) / 2, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?; + Ok(Self { + res1, + res2, + res3, + snake1, + conv1, + }) + } +} + +impl candle::Module for EncoderBlock { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.res1)? + .apply(&self.res2)? + .apply(&self.res3)? + .apply(&self.snake1)? + .apply(&self.conv1) + } +} + +#[derive(Debug, Clone)] +pub struct Encoder { + conv1: Conv1d, + blocks: Vec, + snake1: Snake1d, + conv2: Conv1d, +} + +impl candle::Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.apply(&self.conv1)?; + for block in self.blocks.iter() { + xs = xs.apply(block)? + } + xs.apply(&self.snake1)?.apply(&self.conv2) + } +} + +impl Encoder { + pub fn new( + mut d_model: usize, + strides: &[usize], + d_latent: usize, + vb: VarBuilder, + ) -> Result { + let vb = vb.pp("block"); + let cfg1 = Conv1dConfig { + padding: 3, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?; + let mut blocks = Vec::with_capacity(strides.len()); + for (block_idx, stride) in strides.iter().enumerate() { + d_model *= 2; + let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?; + blocks.push(block) + } + let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?; + let cfg2 = Conv1dConfig { + padding: 1, + ..Default::default() + }; + let conv2 = + encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?; + Ok(Self { + conv1, + blocks, + snake1, + conv2, + }) + } +} + +#[derive(Debug, Clone)] +pub struct DecoderBlock { + snake1: Snake1d, + conv_tr1: ConvTranspose1d, + res1: ResidualUnit, + res2: ResidualUnit, + res3: ResidualUnit, +} + +impl DecoderBlock { + pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result { + let vb = vb.pp("block"); + let snake1 = Snake1d::new(in_dim, vb.pp(0))?; + let cfg = ConvTranspose1dConfig { + stride, + padding: (stride + 1) / 2, + ..Default::default() + }; + let conv_tr1 = encodec::conv_transpose1d_weight_norm( + in_dim, + out_dim, + 2 * stride, + true, + cfg, + vb.pp(1), + )?; + let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?; + let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?; + let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?; + Ok(Self { + snake1, + conv_tr1, + res1, + res2, + res3, + }) + } +} + +impl candle_nn::Module for DecoderBlock { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.snake1)? + .apply(&self.conv_tr1)? + .apply(&self.res1)? + .apply(&self.res2)? + .apply(&self.res3) + } +} + +#[derive(Debug, Clone)] +pub struct Decoder { + conv1: Conv1d, + blocks: Vec, + snake1: Snake1d, + conv2: Conv1d, +} + +impl Decoder { + pub fn new( + in_c: usize, + mut channels: usize, + rates: &[usize], + d_out: usize, + vb: VarBuilder, + ) -> Result { + let vb = vb.pp("model"); + let cfg1 = Conv1dConfig { + padding: 3, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?; + let mut blocks = Vec::with_capacity(rates.len()); + for (idx, stride) in rates.iter().enumerate() { + let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?; + channels /= 2; + blocks.push(block) + } + let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?; + let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?; + Ok(Self { + conv1, + blocks, + snake1, + conv2, + }) + } +} + +impl candle::Module for Decoder { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.apply(&self.conv1)?; + for block in self.blocks.iter() { + xs = xs.apply(block)? + } + xs.apply(&self.snake1)?.apply(&self.conv2) + } +} + +#[allow(unused)] +#[derive(Clone, Debug)] +pub struct VectorQuantizer { + in_proj: Conv1d, + out_proj: Conv1d, + codebook: candle_nn::Embedding, +} + +impl VectorQuantizer { + pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result { + let in_proj = + encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?; + let out_proj = + encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?; + let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?; + Ok(Self { + in_proj, + out_proj, + codebook, + }) + } + + pub fn embed_code(&self, embed_id: &Tensor) -> Result { + embed_id.apply(&self.codebook) + } + + pub fn decode_code(&self, embed_id: &Tensor) -> Result { + self.embed_code(embed_id)?.transpose(1, 2) + } +} + +#[derive(Clone, Debug)] +pub struct ResidualVectorQuantizer { + quantizers: Vec, +} + +impl ResidualVectorQuantizer { + pub fn new( + input_dim: usize, + n_codebooks: usize, + cb_size: usize, + cb_dim: usize, + vb: VarBuilder, + ) -> Result { + let vb = &vb.pp("quantizers"); + let quantizers = (0..n_codebooks) + .map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i))) + .collect::>>()?; + Ok(Self { quantizers }) + } + + pub fn from_codes(&self, codes: &Tensor) -> Result { + let mut sum = None; + for (idx, quantizer) in self.quantizers.iter().enumerate() { + let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?; + let z_q_i = z_p_i.apply(&quantizer.out_proj)?; + let s = match sum { + None => z_q_i, + Some(s) => (s + z_q_i)?, + }; + sum = Some(s) + } + match sum { + Some(s) => Ok(s), + None => candle::bail!("empty codebooks"), + } + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub encoder: Encoder, + pub quantizer: ResidualVectorQuantizer, + pub decoder: Decoder, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb = vb.pp("model"); + let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?; + let quantizer = ResidualVectorQuantizer::new( + cfg.latent_dim, + cfg.num_codebooks, + cfg.codebook_size, + 8, + vb.pp("quantizer"), + )?; + let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?; + Ok(Self { + encoder, + decoder, + quantizer, + }) + } + + pub fn decode_codes(&self, audio_codes: &Tensor) -> Result { + let audio_values = self.quantizer.from_codes(audio_codes)?; + audio_values.apply(&self.decoder) + } +} diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index fb70fb52..ba6686f6 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -136,7 +136,7 @@ pub fn conv1d_weight_norm( Ok(Conv1d::new(weight, Some(bias), config)) } -fn conv_transpose1d_weight_norm( +pub fn conv_transpose1d_weight_norm( in_c: usize, out_c: usize, kernel_size: usize, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 83d13a7b..cc83cf7b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -9,6 +9,7 @@ pub mod clip; pub mod codegeex4_9b; pub mod convmixer; pub mod convnext; +pub mod dac; pub mod depth_anything_v2; pub mod dinov2; pub mod dinov2reg4; diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs index 9c66c93a..16023a7c 100644 --- a/candle-transformers/src/models/parler_tts.rs +++ b/candle-transformers/src/models/parler_tts.rs @@ -31,6 +31,7 @@ pub struct Config { pub decoder: DecoderConfig, pub text_encoder: t5::Config, pub vocab_size: usize, + pub audio_encoder: crate::models::dac::Config, } #[derive(Debug, Clone)] @@ -325,6 +326,7 @@ pub struct Model { pub text_encoder: t5::T5EncoderModel, pub decoder_start_token_id: u32, pub pad_token_id: u32, + pub audio_encoder: crate::models::dac::Model, } impl Model { @@ -347,6 +349,8 @@ impl Model { } else { None }; + let audio_encoder = + crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?; Ok(Self { decoder, text_encoder, @@ -354,6 +358,7 @@ impl Model { enc_to_dec_proj, decoder_start_token_id: cfg.decoder_start_token_id, pad_token_id: cfg.pad_token_id, + audio_encoder, }) }