parler-tts support (#2431)

* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description + T5 encode it.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the python decoder.

* Proper causality mask.
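
Taken together, these steps implement the full pipeline: the description is encoded with the T5 text encoder (and projected to the decoder width when the dimensions differ), the decoder autoregressively samples one token per codebook with a per-codebook delay pattern and pad-token EOS handling, the generated codes are saved to out.safetensors, and the Python script decodes them to out.wav with DAC.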
Author:    Laurent Mazare
Committed: 2024-08-18 19:42:08 +01:00 (via GitHub)
Parent:    736d8eb752
Commit:    58197e1896

5 changed files with 658 additions and 0 deletions

.gitignore

@@ -41,3 +41,4 @@ candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
out.safetensors

[new file: Python DAC decoding script]

@@ -0,0 +1,29 @@
import torch
import torchaudio
from safetensors.torch import load_file
from parler_tts import DACModel

# Load the audio codes produced by the Rust example.
tensors = load_file("out.safetensors")
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
output_ids = tensors["codes"][None, None]
print(output_ids, "\n", output_ids.shape)
batch_size = 1
with torch.no_grad():
    output_values = []
    for sample_id in range(batch_size):
        sample = output_ids[:, sample_id]
        # Keep only the frames where every codebook entry is a valid code.
        sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
        if sample_mask.sum() > 0:
            sample = sample[:, :, sample_mask]
            sample = dac_model.decode(sample[None, ...], [None]).audio_values
            output_values.append(sample.transpose(0, 2))
        else:
            output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
output_lengths = [audio.shape[0] for audio in output_values]
pcm = (
    torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
    .squeeze(-1)
    .squeeze(-1)
)
print(pcm.shape, pcm.dtype)
torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)
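
This script consumes the out.safetensors file written by the Rust example (via codes.save_safetensors) and writes out.wav at 44.1 kHz; it requires the parler_tts Python package, which provides DACModel.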

[new file: Rust example (main.rs)]

@@ -0,0 +1,175 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
#[arg(
long,
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
)]
description: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value_t = 512)]
max_steps: usize,
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "parler-tts/parler-tts-large-v1".to_string(),
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
let mut model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let description_tokens = tokenizer
.encode(args.description, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
println!("{description_tokens}");
let prompt_tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
println!("{prompt_tokens}");
let lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
args.top_p,
);
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
println!("{codes}");
let codes = codes.to_dtype(DType::I64)?;
codes.save_safetensors("codes", "out.safetensors")?;
Ok(())
}
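
A typical invocation, assuming the example is registered as parler-tts in candle-examples (the exact name is not visible in this view), is:

cargo run --example parler-tts --release -- --prompt "Hey, how are you doing today?"

followed by the Python script above to turn out.safetensors into out.wav.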

[changed file: models module list (mod.rs)]

@@ -40,6 +40,7 @@ pub mod mobileone;
pub mod moondream;
pub mod mpt;
pub mod olmo;
pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;

[new file: parler_tts model implementation]

@@ -0,0 +1,452 @@
use crate::generation::LogitsProcessor;
use crate::models::t5;
use candle::{IndexOp, Result, Tensor};
use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct DecoderConfig {
pub vocab_size: usize,
pub max_position_embeddings: usize,
pub num_hidden_layers: usize,
pub ffn_dim: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: Option<usize>,
pub num_cross_attention_key_value_heads: Option<usize>,
pub activation_function: Activation,
pub hidden_size: usize,
pub scale_embedding: bool,
pub num_codebooks: usize,
pub pad_token_id: usize,
pub bos_token_id: usize,
pub eos_token_id: usize,
pub tie_word_embeddings: bool,
pub rope_embeddings: bool,
pub rope_theta: f64,
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub decoder_start_token_id: u32,
pub pad_token_id: u32,
pub decoder: DecoderConfig,
pub text_encoder: t5::Config,
pub vocab_size: usize,
}
#[derive(Debug, Clone)]
pub struct Attention {
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
is_causal: bool,
kv_cache: Option<(Tensor, Tensor)>,
scaling: f64,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
}
impl Attention {
fn new(
num_kv_heads: usize,
is_causal: bool,
cfg: &DecoderConfig,
vb: VarBuilder,
) -> Result<Self> {
if cfg.rope_embeddings {
candle::bail!("rope embeddings are not supported");
}
let embed_dim = cfg.hidden_size;
let head_dim = embed_dim / cfg.num_attention_heads;
let kv_out_dim = num_kv_heads * head_dim;
let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
Ok(Self {
k_proj,
v_proj,
q_proj,
out_proj,
is_causal,
kv_cache: None,
scaling: (head_dim as f64).powf(-0.5),
num_heads: cfg.num_attention_heads,
num_kv_heads,
num_kv_groups: cfg.num_attention_heads / num_kv_heads,
head_dim,
})
}
fn forward(
&mut self,
xs: &Tensor,
key_value_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
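        // The 1/sqrt(head_dim) scaling is folded into the query projection output
        // rather than applied to the attention scores.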
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
.reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let key_states = match key_value_states {
Some(states) => states.apply(&self.k_proj)?,
None => xs.apply(&self.k_proj)?,
};
let key_states = key_states
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let value_states = match key_value_states {
Some(states) => states.apply(&self.v_proj)?,
None => xs.apply(&self.v_proj)?,
};
let value_states = value_states
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
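        // Only the causal self-attention keeps a KV cache across decoding steps;
        // cross-attention K/V are recomputed from the fixed encoder output on each call.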
if self.is_causal {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&value_states)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, tgt_len, ()))?
.apply(&self.out_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
pub struct DecoderLayer {
self_attn: Attention,
self_attn_layer_norm: LayerNorm,
encoder_attn: Attention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
activation: Activation,
}
impl DecoderLayer {
fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
let self_attn_layer_norm =
layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn_layer_norm =
layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
activation: cfg.activation_function,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
encoder_xs: &Tensor,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
// Self attention
let residual = xs;
let xs = xs.apply(&self.self_attn_layer_norm)?;
let xs = self.self_attn.forward(&xs, None, attention_mask)?;
let xs = (residual + xs)?;
// Cross attention
let residual = &xs;
let xs = xs.apply(&self.encoder_attn_layer_norm)?;
let xs = self
.encoder_attn
.forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
let xs = (residual + xs)?;
// Fully connected
let residual = &xs;
let xs = xs
.apply(&self.final_layer_norm)?
.apply(&self.fc1)?
.apply(&self.activation)?
.apply(&self.fc2)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
self.encoder_attn.clear_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct Decoder {
embed_tokens: Vec<candle_nn::Embedding>,
embed_positions: Tensor,
layers: Vec<DecoderLayer>,
layer_norm: LayerNorm,
num_codebooks: usize,
hidden_size: usize,
lm_heads: Vec<Linear>,
dtype: candle::DType,
}
impl Decoder {
pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
let vb_d = vb.pp("model.decoder");
let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
let vb_e = vb_d.pp("embed_tokens");
for embed_idx in 0..cfg.num_codebooks {
let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
embed_tokens.push(e)
}
let embed_positions = vb_d.get(
(cfg.max_position_embeddings, cfg.hidden_size),
"embed_positions.weights",
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_d.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
let vb_l = vb.pp("lm_heads");
for lm_idx in 0..cfg.num_codebooks {
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
lm_heads.push(lm_head)
}
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
num_codebooks: cfg.num_codebooks,
lm_heads,
hidden_size: cfg.hidden_size,
dtype: vb.dtype(),
})
}
pub fn forward(
&mut self,
input_ids: &Tensor,
prompt_hidden_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_xs: &Tensor,
encoder_attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Vec<Tensor>> {
let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
if num_codebooks != self.num_codebooks {
candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
}
let mut inputs_embeds = Tensor::zeros(
(b_sz, seq_len, self.hidden_size),
self.dtype,
input_ids.device(),
)?;
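        // The decoder input embedding is the sum of the per-codebook token embeddings.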
for (idx, embs) in self.embed_tokens.iter().enumerate() {
let e = input_ids.i((.., idx))?.apply(embs)?;
inputs_embeds = (inputs_embeds + e)?
}
let inputs_embeds = match prompt_hidden_states {
None => inputs_embeds,
Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
};
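        // Slice the learned positional embeddings starting at the current offset.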
let embed_positions = self
.embed_positions
.i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
}
let xs = xs.apply(&self.layer_norm)?;
let mut lm_logits = Vec::with_capacity(self.num_codebooks);
for lm_head in self.lm_heads.iter() {
let logits = xs.apply(lm_head)?;
lm_logits.push(logits)
}
Ok(lm_logits)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
#[derive(Debug, Clone)]
pub struct Model {
pub embed_prompts: candle_nn::Embedding,
pub enc_to_dec_proj: Option<Linear>,
pub decoder: Decoder,
pub text_encoder: t5::T5EncoderModel,
pub decoder_start_token_id: u32,
pub pad_token_id: u32,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
let embed_prompts = candle_nn::embedding(
cfg.vocab_size,
cfg.decoder.hidden_size,
vb.pp("embed_prompts"),
)?;
let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
let proj = linear(
cfg.text_encoder.d_model,
cfg.decoder.hidden_size,
true,
vb.pp("enc_to_dec_proj"),
)?;
Some(proj)
} else {
None
};
Ok(Self {
decoder,
text_encoder,
embed_prompts,
enc_to_dec_proj,
decoder_start_token_id: cfg.decoder_start_token_id,
pad_token_id: cfg.pad_token_id,
})
}
/// Note that the returned tensor uses the CPU device.
pub fn generate(
&mut self,
prompt_tokens: &Tensor,
description_tokens: &Tensor,
mut lp: LogitsProcessor,
max_steps: usize,
) -> Result<Tensor> {
self.decoder.clear_kv_cache();
self.text_encoder.clear_kv_cache();
let encoded = self.text_encoder.forward(description_tokens)?;
let encoded = match self.enc_to_dec_proj.as_ref() {
None => encoded,
Some(proj) => encoded.apply(proj)?,
};
let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
let num_codebooks = self.decoder.num_codebooks;
let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
let mut all_audio_tokens = vec![vec![]; num_codebooks];
let prompt_len = prompt_hidden_states.dim(1)?;
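        // Step 0 feeds the prompt embeddings as a prefix; every later step feeds a
        // single token per codebook and relies on the KV cache for past context.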
for step in 0..max_steps {
let input_ids = Tensor::from_slice(
audio_tokens.as_slice(),
(1, num_codebooks, 1),
prompt_tokens.device(),
)?;
let (prompt_hidden_states, pos) = if step == 0 {
(Some(&prompt_hidden_states), 0)
} else {
(None, step + prompt_len)
};
let causal_mask = if pos == 0 {
self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
} else {
self.prepare_causal_mask(1, pos + 1, input_ids.device())?
};
let logits = self.decoder.forward(
&input_ids,
prompt_hidden_states,
Some(&causal_mask),
&encoded,
None,
pos,
)?;
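            // Delay pattern: codebook k only starts sampling at step k, and a
            // codebook that has emitted the pad token is never resampled.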
for (logit_idx, logit) in logits.iter().enumerate() {
if logit_idx > step {
break;
}
if audio_tokens[logit_idx] != self.pad_token_id {
let logit = logit.i((0, logit.dim(1)? - 1))?;
let token = lp.sample(&logit)?;
audio_tokens[logit_idx] = token
}
}
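            // EOS: stop once every codebook has produced the pad token.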
if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
break;
}
for (cb_idx, &token) in audio_tokens.iter().enumerate() {
if token != self.decoder_start_token_id && token != self.pad_token_id {
all_audio_tokens[cb_idx].push(token)
}
}
}
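        // Trim all codebooks to a common length and terminate each with a pad token.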
let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
all_audio_tokens.iter_mut().for_each(|v| {
v.resize(min_len, 0);
v.push(self.pad_token_id)
});
let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
Ok(all_audio_tokens)
}
fn prepare_causal_mask(
&self,
q_len: usize,
kv_len: usize,
device: &candle::Device,
) -> Result<Tensor> {
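        // Key position j is masked for query row i when it lies strictly in the
        // future once queries are right-aligned with the keys, i.e. when
        // i + kv_len < j + q_len. With q_len == 1 the row is all zeros, so a newly
        // decoded token attends to the entire cache.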
let mask: Vec<_> = (0..q_len)
.flat_map(|i| {
(0..kv_len).map(move |j| {
if i + kv_len < j + q_len {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
Tensor::from_slice(&mask, (q_len, kv_len), device)
}
}