From 4fff5b51f54f1b14216b7293ec9ca674bef2a904 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 2 Mar 2024 18:50:01 +0100
Subject: [PATCH] Metavoice - first cut (#1717)

* Add the metavoice transformer.

* Sketch the speaker-encoder module.

* Adding to the metavoice model.

* Start adding the metavoice example.

* Get some logits out.

* Load the second stage model.

* Get the second step to run.

* Tweak the example.

* Add encodec tilting.

* Glue the different bits together.

* Fix a shape issue.

* Use a constant.

* BPE tokenization.

* Add a warning.
---
 Cargo.toml                                   |   1 +
 candle-examples/examples/metavoice/README.md |  18 +
 candle-examples/examples/metavoice/main.rs   | 218 +++
 candle-transformers/Cargo.toml               |   1 +
 candle-transformers/src/models/metavoice.rs  | 878 +++++++++++++++++++
 candle-transformers/src/models/mod.rs        |   1 +
 6 files changed, 1117 insertions(+)
 create mode 100644 candle-examples/examples/metavoice/README.md
 create mode 100644 candle-examples/examples/metavoice/main.rs
 create mode 100644 candle-transformers/src/models/metavoice.rs

diff --git a/Cargo.toml b/Cargo.toml
index d591e1d7..40f51fea 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,6 +42,7 @@ candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.10.0", features = ["f16"] }
+fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
diff --git a/candle-examples/examples/metavoice/README.md b/candle-examples/examples/metavoice/README.md
new file mode 100644
index 00000000..ef53e66f
--- /dev/null
+++ b/candle-examples/examples/metavoice/README.md
@@ -0,0 +1,18 @@
+# candle-metavoice
+
+MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
+details on the [model
+card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
+
+Note that the current candle implementation suffers from some limitations as of
+2024-03-02:
+- The speaker embeddings are hardcoded.
+- The generated audio file quality is weaker than the Python implementation,
+  probably because of some implementation discrepancies.
+
+## Run an example
+
+```bash
+cargo run --example metavoice --release -- \
+    --prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
+```
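+
+All of the flags below are optional knobs exposed by the example; the values
+here are illustrative rather than tuned:
+
+```bash
+cargo run --example metavoice --release -- \
+    --prompt "Hello world" \
+    --guidance-scale 2.0 --temperature 0.8 --seed 42 --out-file hello.wav
+```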
diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs
new file mode 100644
index 00000000..6788976a
--- /dev/null
+++ b/candle-examples/examples/metavoice/main.rs
@@ -0,0 +1,218 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use clap::Parser;
+
+use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::encodec;
+use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
+
+use candle::{DType, IndexOp, Tensor};
+use candle_nn::VarBuilder;
+use hf_hub::api::sync::Api;
+use rand::{distributions::Distribution, SeedableRng};
+
+pub const ENCODEC_NTOKENS: u32 = 1024;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The guidance scale.
+    #[arg(long, default_value_t = 3.0)]
+    guidance_scale: f64,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 1.0)]
+    temperature: f64,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The output file using the wav format.
+    #[arg(long, default_value = "out.wav")]
+    out_file: String,
+
+    #[arg(long)]
+    first_stage_meta: Option<String>,
+
+    #[arg(long)]
+    first_stage_weights: Option<String>,
+
+    #[arg(long)]
+    second_stage_weights: Option<String>,
+
+    #[arg(long)]
+    encodec_weights: Option<String>,
+
+    #[arg(long)]
+    spk_emb: Option<String>,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    let device = candle_examples::device(args.cpu)?;
+    let api = Api::new()?;
+    let repo = api.model("lmz/candle-metavoice".to_string());
+    let first_stage_meta = match &args.first_stage_meta {
+        Some(w) => std::path::PathBuf::from(w),
+        None => repo.get("first_stage.meta.json")?,
+    };
+    let first_stage_meta: serde_json::Value =
+        serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
+    let first_stage_tokenizer = match first_stage_meta.as_object() {
+        None => anyhow::bail!("not a json object"),
+        Some(j) => match j.get("tokenizer") {
+            None => anyhow::bail!("no tokenizer key"),
+            Some(j) => j,
+        },
+    };
+    let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
+
+    let first_stage_weights = match &args.first_stage_weights {
+        Some(w) => std::path::PathBuf::from(w),
+        None => repo.get("first_stage.safetensors")?,
+    };
+    let second_stage_weights = match &args.second_stage_weights {
+        Some(w) => std::path::PathBuf::from(w),
+        None => repo.get("second_stage.safetensors")?,
+    };
+    let encodec_weights = match args.encodec_weights {
+        Some(w) => std::path::PathBuf::from(w),
+        None => Api::new()?
+            .model("facebook/encodec_24khz".to_string())
+            .get("model.safetensors")?,
+    };
+    let first_stage_vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(&[first_stage_weights], DType::F32, &device)?
+    };
+    let first_stage_config = transformer::Config::cfg1b_v0_1();
+    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
+
+    let second_stage_vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
+    };
+    let second_stage_config = gpt::Config::cfg1b_v0_1();
+    let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
+
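+    // EnCodec (24kHz) is only used as a decoder here: it turns the audio codes
+    // produced by the two stages back into waveform samples.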
+    let encodec_vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], DType::F32, &device)? };
+    let encodec_config = encodec::Config::default();
+    let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
+
+    println!("prompt: '{}'", args.prompt);
+    let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
+    let mut tokens = prompt_tokens.clone();
+    println!("{tokens:?}");
+    let spk_emb_file = match &args.spk_emb {
+        Some(w) => std::path::PathBuf::from(w),
+        None => repo.get("spk_emb.safetensors")?,
+    };
+    let spk_emb = candle::safetensors::load(&spk_emb_file, &device)?;
+    let spk_emb = match spk_emb.get("spk_emb") {
+        None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
+        Some(spk_emb) => spk_emb.to_dtype(DType::F32)?,
+    };
+    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), None);
+
+    // First stage generation.
+    for index in 0.. {
+        let context_size = if index > 0 { 1 } else { tokens.len() };
+        let start_pos = tokens.len().saturating_sub(context_size);
+        let ctxt = &tokens[start_pos..];
+        let input = Tensor::new(ctxt, &device)?;
+        let input = Tensor::stack(&[&input, &input], 0)?;
+        let logits = first_stage_model.forward(&input, &spk_emb, start_pos)?;
+        let logits0 = logits.i((0, 0))?;
+        let logits1 = logits.i((1, 0))?;
+        let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
+        let logits = logits.to_dtype(DType::F32)?;
+        let next_token = logits_processor.sample(&logits)?;
+        println!("{} {next_token}", tokens.len());
+        tokens.push(next_token);
+        if next_token == 2048 {
+            break;
+        }
+    }
+    let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
+    let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
+    println!("text ids len: {}", text_ids.len());
+    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
+    // TODO: Use the config rather than hardcoding the offset here.
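+    // Per the adapters module: ids below ENCODEC_NTOKENS (1024) are codebook-0
+    // tokens, ids in [1024, 2048) are codebook-1 tokens, 2048 marks
+    // end-of-audio, and text tokens sit above it; re-basing the prompt tokens
+    // by 1024 shifts them into the second stage's smaller vocabulary.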
+    let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
+    let hierarchies_in1 = [encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
+    let hierarchies_in2 = [
+        vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
+        ids2.as_slice(),
+        &[ENCODEC_NTOKENS],
+    ]
+    .concat();
+    let in_x1 = Tensor::new(hierarchies_in1, &device)?;
+    let in_x2 = Tensor::new(hierarchies_in2, &device)?;
+    let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
+    let logits = second_stage_model.forward(&in_x)?;
+    let mut codes = vec![];
+    for (idx, logits) in logits.iter().enumerate() {
+        println!("{idx} {logits}");
+        let logits = logits.squeeze(0)?;
+        let (seq_len, _) = logits.dims2()?;
+        let mut codes_ = Vec::with_capacity(seq_len);
+        for step in 0..seq_len {
+            let logits = logits.i(step)?.to_dtype(DType::F32)?;
+            let logits = &(&logits / 1.0)?;
+            let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
+            let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
+            let sample = distr.sample(&mut rng) as u32;
+            codes_.push(sample)
+        }
+        codes.push(codes_)
+    }
+
+    let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
+    let codes = Tensor::cat(&[in_x, codes], 1)?;
+    println!("codes: {codes}");
+    let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
+    let codes = codes.i(0)?.to_vec2::<u32>()?;
+    let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
+    println!("text_ids len: {:?}", text_ids.len());
+    let audio_ids = Tensor::new(audio_ids, &device)?.unsqueeze(0)?;
+    println!("audio_ids shape: {:?}", audio_ids.shape());
+    let pcm = encodec_model.decode(&audio_ids)?;
+    println!("output pcm shape: {:?}", pcm.shape());
+    let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
+    let mut output = std::fs::File::create(&args.out_file)?;
+    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
+    Ok(())
+}
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 0e55ab8c..6589b4b1 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -15,6 +15,7 @@ byteorder = { workspace = true }
 candle = { workspace = true }
 candle-flash-attn = { workspace = true, optional = true }
 candle-nn = { workspace = true }
+fancy-regex = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rand = { workspace = true }
diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
new file mode 100644
index 00000000..e37c168c
--- /dev/null
+++ b/candle-transformers/src/models/metavoice.rs
@@ -0,0 +1,878 @@
+use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+
+// Equivalent to torch.repeat_interleave
+fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+    let img = img.unsqueeze(dim + 1)?;
+    let mut dims = img.dims().to_vec();
+    dims[dim + 1] = repeats;
+    img.broadcast_as(dims)?.flatten(dim, dim + 1)
+}
+
+pub mod speaker_encoder {
+    use super::*;
+
+    #[derive(Debug, Clone, serde::Deserialize)]
+    pub struct Config {
+        pub mel_window_step: usize,
+        pub mel_n_channels: usize,
+        pub sampling_rate: usize,
+        pub partial_n_frames: usize,
+        pub model_hidden_size: usize,
+        pub model_embedding_size: usize,
+        pub model_num_layers: usize,
+    }
+
+    pub struct Model {
+        lstms: Vec<candle_nn::LSTM>,
+        linear: Linear,
+    }
+
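+    // Sketch of the reference speaker encoder: a stack of LSTMs whose final
+    // hidden states are projected by a linear layer. embed_utterance is still
+    // a todo!(), which is why the example relies on a precomputed spk_emb
+    // tensor rather than deriving one from audio.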
+    impl Model {
+        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let mut lstms = Vec::with_capacity(cfg.model_num_layers);
+            let vb_l = vb.pp("lstm");
+            for layer_idx in 0..cfg.model_num_layers {
+                let c = candle_nn::LSTMConfig {
+                    layer_idx,
+                    ..Default::default()
+                };
+                let lstm = candle_nn::lstm(
+                    cfg.mel_n_channels,
+                    cfg.model_hidden_size,
+                    c,
+                    vb_l.pp(layer_idx),
+                )?;
+                lstms.push(lstm)
+            }
+            let linear = linear_b(
+                cfg.model_hidden_size,
+                cfg.model_embedding_size,
+                true,
+                vb.pp("linear"),
+            )?;
+            Ok(Self { lstms, linear })
+        }
+
+        fn compute_partial_slices(
+            _n_samples: usize,
+            _rate: f64,
+            _min_coverage: f64,
+        ) -> Result<(Tensor, Tensor)> {
+            todo!()
+        }
+
+        pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
+            let (_wav_slices, _mel_slices) =
+                Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
+            todo!()
+        }
+    }
+
+    impl Module for Model {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            use candle_nn::RNN;
+            let mut xs = xs.clone();
+            for lstm in self.lstms.iter() {
+                let res = lstm.seq(&xs)?;
+                let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
+                xs = Tensor::stack(&res, 1)?;
+            }
+            let embeds_raw = xs.apply(&self.linear)?.relu()?;
+            // TODO: normalize.
+            Ok(embeds_raw)
+        }
+    }
+}
+
+type Rank = u32;
+
+pub mod tokenizers {
+    use super::*;
+    use std::collections::HashMap;
+
+    pub struct BPE {
+        pub re: fancy_regex::Regex,
+        pub end_of_text: usize,
+        pub offset: usize,
+        pub ranks: HashMap<Vec<u8>, Rank>,
+    }
+
+    impl BPE {
+        pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
+            let json = match json.as_object() {
+                None => candle::bail!("json value is not an object"),
+                Some(json) => json,
+            };
+            let re = match json.get("pat_str") {
+                None => candle::bail!("json object has no pat_str field"),
+                Some(pat_str) => match pat_str.as_str() {
+                    None => candle::bail!("pat_str field is not a string"),
+                    Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
+                },
+            };
+            let offset = match json.get("offset") {
+                None => candle::bail!("json object has no offset field"),
+                Some(offset) => match offset.as_u64() {
+                    None => candle::bail!("offset field is not a positive int"),
+                    Some(offset) => offset as usize,
+                },
+            };
+            let mut ranks = HashMap::new();
+            for id in 0u8..=255 {
+                ranks.insert(vec![id], id as u32);
+            }
+            let mergeable_ranks = match json.get("mergeable_ranks") {
+                None => candle::bail!("json object has no mergeable_ranks field"),
+                Some(mr) => match mr.as_object() {
+                    None => candle::bail!("mergeable_ranks is not an object"),
+                    Some(mr) => mr,
+                },
+            };
+            for (key, value) in mergeable_ranks.iter() {
+                let value = match value.as_u64() {
+                    None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
+                    Some(value) => value as u32,
+                };
+                if value < 256 {
+                    continue;
+                }
+                // No escaping for other keys.
+                let key = key.as_bytes().to_vec();
+                ranks.insert(key, value);
+            }
+            Ok(Self {
+                re,
+                end_of_text,
+                offset,
+                ranks,
+            })
+        }
+
+        // Taken from:
+        // https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/src/lib.rs#L16C1-L82C2
+        fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
+            // This is a vector of (start, rank).
+            // The rank is of the pair starting at position start.
+            let mut parts = Vec::with_capacity(piece.len() + 1);
+
+            // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+            // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+            // merge priority from token index or to prevent specific token merges.
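+            // The merge loop below: scan for the adjacent pair with the lowest
+            // rank, merge it (by deleting its right boundary), recompute the
+            // ranks around the merge point, and repeat until nothing merges.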
+            let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
+            for i in 0..piece.len() - 1 {
+                let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+                if rank < min_rank.0 {
+                    min_rank = (rank, i);
+                }
+                parts.push((i, rank));
+            }
+            parts.push((piece.len() - 1, Rank::MAX));
+            parts.push((piece.len(), Rank::MAX));
+
+            let get_rank = {
+                #[inline(always)]
+                |parts: &Vec<(usize, Rank)>, i: usize| {
+                    if (i + 3) < parts.len() {
+                        // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                        // parts[i + 1], see comment in the main loop.
+                        *self
+                            .ranks
+                            .get(&piece[parts[i].0..parts[i + 3].0])
+                            .unwrap_or(&Rank::MAX)
+                    } else {
+                        Rank::MAX
+                    }
+                }
+            };
+
+            // If you have n parts and m merges, this does O(mn) work.
+            // We could do something with a heap and do O(m log n) work.
+            // n is often very small so considerations like cache-locality outweigh the algorithmic
+            // complexity downsides of the `parts` vector.
+            while min_rank.0 != Rank::MAX {
+                let i = min_rank.1;
+                // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+                // `parts.remove(i + 1)` will thrash the cache.
+                if i > 0 {
+                    parts[i - 1].1 = get_rank(&parts, i - 1);
+                }
+                parts[i].1 = get_rank(&parts, i);
+                parts.remove(i + 1);
+
+                min_rank = (Rank::MAX, usize::MAX);
+                for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
+                    if rank < min_rank.0 {
+                        min_rank = (rank, i);
+                    }
+                }
+            }
+            parts
+        }
+
+        pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
+            if piece.is_empty() {
+                return Vec::new();
+            }
+            if piece.len() == 1 {
+                return vec![self.ranks[piece]];
+            }
+            assert!(piece.len() > 1);
+            self._byte_pair_merge(piece)
+                .windows(2)
+                .map(|part| self.ranks[&piece[part[0].0..part[1].0]])
+                .collect()
+        }
+
+        pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
+            let mut bpe_tokens: Vec<u32> = Vec::new();
+            for word in self.re.find_iter(text) {
+                let word = word.map_err(E::wrap)?;
+                let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
+                for &token in word_tokens.iter() {
+                    bpe_tokens.push(token + self.offset as u32)
+                }
+            }
+            bpe_tokens.push((self.end_of_text + self.offset) as u32);
+            Ok(bpe_tokens)
+        }
+    }
+}
+
+pub mod gpt {
+    use super::*;
+
+    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+    pub enum NormType {
+        LayerNorm,
+        RMSNorm,
+    }
+
+    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+    pub enum AttnKernelType {
+        Fa2,
+        TorchAttn,
+        Hand,
+    }
+
+    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+    pub enum NonLinearityType {
+        Gelu,
+        Swiglu,
+    }
+
+    enum Norm {
+        RMSNorm(candle_nn::RmsNorm),
+        LayerNorm(candle_nn::LayerNorm),
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L27
+    #[derive(Debug, Clone)]
+    pub struct Config {
+        pub block_size: usize,
+        pub vocab_sizes: Vec<usize>,
+        pub target_vocab_sizes: Vec<usize>,
+        pub n_layer: usize,
+        pub n_head: usize,
+        pub n_embd: usize,
+        pub bias: bool,
+        pub causal: bool,
+        pub spk_emb_on_text: bool,
+        pub norm_type: NormType,
+        pub rmsnorm_eps: f64,
+        pub nonlinearity_type: NonLinearityType,
+        pub swiglu_multiple_of: Option<usize>,
+        pub attn_kernel_type: AttnKernelType,
+        pub kv_cache_enabled: bool,
+    }
+
+    impl Config {
+        pub fn cfg1b_v0_1() -> Self {
+            Self {
+                n_layer: 6,
+                n_head: 6,
+                n_embd: 384,
+                block_size: 1024,
+                bias: false,
+                vocab_sizes: vec![1538, 1025],
+                causal: false,
+                target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
+                swiglu_multiple_of: Some(256),
+                norm_type: NormType::RMSNorm,
+                kv_cache_enabled: false,
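+                // Only the plain attention path is implemented;
+                // SelfAttention::new below bails out on Fa2/Hand.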
+                attn_kernel_type: AttnKernelType::TorchAttn,
+                spk_emb_on_text: true,
+                nonlinearity_type: NonLinearityType::Gelu,
+                rmsnorm_eps: 1e-5,
+            }
+        }
+    }
+
+    impl Norm {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            match cfg.norm_type {
+                NormType::RMSNorm => {
+                    let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
+                    Ok(Self::RMSNorm(rms_norm))
+                }
+                NormType::LayerNorm => {
+                    let ln_cfg = candle_nn::LayerNormConfig {
+                        affine: cfg.bias,
+                        ..Default::default()
+                    };
+                    let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
+                    Ok(Self::LayerNorm(layer_norm))
+                }
+            }
+        }
+    }
+
+    impl Module for Norm {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            match self {
+                Self::RMSNorm(m) => m.forward(xs),
+                Self::LayerNorm(m) => m.forward(xs),
+            }
+        }
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/attn.py#L18
+    struct SelfAttention {
+        c_attn: Linear,
+        c_proj: Linear,
+        n_head: usize,
+    }
+
+    impl SelfAttention {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            // The different attention variants are likely to be identical but still we only accept
+            // TorchAttn for now.
+            if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
+                candle::bail!("only TorchAttn is supported")
+            }
+            if cfg.kv_cache_enabled {
+                candle::bail!("kv_cache_enabled=true is not supported")
+            }
+            let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
+            let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
+            Ok(Self {
+                c_attn,
+                c_proj,
+                n_head: cfg.n_head,
+            })
+        }
+    }
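+
+    // Note: cfg1b_v0_1 sets causal to false, so the forward pass below applies
+    // no attention mask (hence the TODO); the first-stage transformer further
+    // down builds its own triangular mask instead.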
+    impl Module for SelfAttention {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let (b, t, c) = xs.dims3()?;
+            let c_x = xs
+                .apply(&self.c_attn)?
+                .reshape((b, t, 3, self.n_head, c / self.n_head))?;
+            let q = c_x.i((.., .., 0))?;
+            let k = c_x.i((.., .., 1))?;
+            let v = c_x.i((.., .., 2))?;
+            let q = q.transpose(1, 2)?.contiguous()?;
+            let k = k.transpose(1, 2)?.contiguous()?;
+            let v = v.transpose(1, 2)?.contiguous()?;
+            let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
+            // TODO: causal mask
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
+            let att = att.matmul(&v)?.transpose(1, 2)?;
+            att.reshape((b, t, c))?.apply(&self.c_proj)
+        }
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/layers.py#L43
+    #[allow(clippy::upper_case_acronyms)]
+    enum MLP {
+        Gelu {
+            c_fc: Linear,
+            c_proj: Linear,
+        },
+        Swiglu {
+            w1: Linear,
+            w3: Linear,
+            c_proj: Linear,
+        },
+    }
+
+    impl MLP {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let hidden_dim = 4 * cfg.n_embd;
+            let slf = match cfg.nonlinearity_type {
+                NonLinearityType::Gelu => {
+                    let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
+                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
+                    Self::Gelu { c_fc, c_proj }
+                }
+                NonLinearityType::Swiglu => {
+                    let hidden_dim = (2 * hidden_dim) / 3;
+                    let swiglu_multiple_of = match cfg.swiglu_multiple_of {
+                        None => candle::bail!("swiglu-multiple-of has to be set"),
+                        Some(smo) => smo,
+                    };
+                    // Round the hidden dimension up to a multiple of swiglu_multiple_of.
+                    let hidden_dim = (hidden_dim + swiglu_multiple_of - 1) / swiglu_multiple_of
+                        * swiglu_multiple_of;
+                    let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
+                    let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
+                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
+                    Self::Swiglu { w1, w3, c_proj }
+                }
+            };
+            Ok(slf)
+        }
+    }
+
+    impl Module for MLP {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            match self {
+                Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
+                Self::Swiglu { w1, w3, c_proj } => {
+                    let w1 = xs.apply(w1)?;
+                    let w3 = xs.apply(w3)?;
+                    (w1.silu()? * w3)?.apply(c_proj)
+                }
+            }
+        }
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/combined.py#L7
+    struct Block {
+        ln_1: Norm,
+        ln_2: Norm,
+        attn: SelfAttention,
+        mlp: MLP,
+    }
+
+    impl Block {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
+            let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
+            let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
+            let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+            Ok(Block {
+                ln_1,
+                ln_2,
+                attn,
+                mlp,
+            })
+        }
+    }
+
+    impl Module for Block {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
+            let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
+            Ok(xs)
+        }
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L79
+    #[allow(clippy::upper_case_acronyms)]
+    pub struct Model {
+        wtes: Vec<candle_nn::Embedding>,
+        wpe: candle_nn::Embedding,
+        h: Vec<Block>,
+        ln_f: Norm,
+        lm_heads: Vec<Linear>,
+        cfg: Config,
+    }
+
+    impl Model {
+        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
+            let vb_t = vb.pp("transformer");
+            let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
+            let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
+            let vb_w = vb_t.pp("wtes");
+            for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
+                let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
+                wtes.push(wte)
+            }
+            let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;
+
+            let mut h = Vec::with_capacity(cfg.n_layer);
+            let vb_h = vb_t.pp("h");
+            for idx in 0..cfg.n_layer {
+                let block = Block::new(&cfg, vb_h.pp(idx))?;
+                h.push(block)
+            }
+
+            let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
+            let vb_l = vb.pp("lm_heads");
+            for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
+                let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
+                lm_heads.push(head)
+            }
+            Ok(Self {
+                wtes,
+                wpe,
+                h,
+                ln_f,
+                lm_heads,
+                cfg,
+            })
+        }
+
+        pub fn config(&self) -> &Config {
+            &self.cfg
+        }
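+
+        // One forward pass maps the two input hierarchies to one logits tensor
+        // per target codebook: the per-hierarchy embeddings are summed, run
+        // through the shared trunk, and each lm_head projects the result.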
+        pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
+            let device = idx.device();
+            let (b, _num_hierarchies, t) = idx.dims3()?;
+            let pos = Tensor::arange(0u32, t as u32, device)?;
+            let pos_emb = pos.apply(&self.wpe)?;
+            let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), DType::F32, device)?;
+            for (wte_idx, wte) in self.wtes.iter().enumerate() {
+                let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
+                tok_emb = (tok_emb + emb)?;
+            }
+            // TODO: speaker embs.
+            let spk_emb = 0f64;
+            let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
+            for block in self.h.iter() {
+                xs = xs.apply(block)?
+            }
+            let xs = xs.apply(&self.ln_f)?;
+            let mut logits = Vec::with_capacity(self.lm_heads.len());
+            for lm_head in self.lm_heads.iter() {
+                // non-causal mode only.
+                let ys = xs.apply(lm_head)?;
+                logits.push(ys)
+            }
+            Ok(logits)
+        }
+    }
+}
+
+pub mod transformer {
+    use super::*;
+
+    #[derive(Debug, Clone, serde::Deserialize)]
+    pub struct Config {
+        pub block_size: usize,
+        pub vocab_size: usize,
+        pub n_layer: usize,
+        pub n_head: usize,
+        pub dim: usize,
+        pub speaker_emb_dim: usize,
+        pub intermediate_size: Option<usize>,
+        pub n_local_heads: Option<usize>,
+        pub norm_eps: f64,
+    }
+
+    impl Config {
+        pub fn cfg1b_v0_1() -> Self {
+            Self {
+                n_layer: 24,
+                n_head: 16,
+                dim: 2048,
+                vocab_size: 2562,
+                speaker_emb_dim: 256,
+                block_size: 2048,
+                intermediate_size: None,
+                n_local_heads: None,
+                norm_eps: 1e-5,
+            }
+        }
+
+        fn n_local_heads(&self) -> usize {
+            self.n_local_heads.unwrap_or(self.n_head)
+        }
+
+        fn head_dim(&self) -> usize {
+            self.dim / self.n_head
+        }
+
+        fn intermediate_size(&self) -> usize {
+            match self.intermediate_size {
+                Some(intermediate_size) => intermediate_size,
+                None => {
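+                    // e.g. with cfg1b_v0_1: 4 * 2048 = 8192, 2 * 8192 / 3 =
+                    // 5461, rounded up to a multiple of 256 gives 5632.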
+                    let hidden_dim = self.dim * 4;
+                    let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
+                    (n_hidden + 255) / 256 * 256
+                }
+            }
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct FeedForward {
+        w1: Linear,
+        w2: Linear,
+        w3: Linear,
+    }
+
+    impl FeedForward {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let i_size = cfg.intermediate_size();
+            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
+            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
+            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
+            Ok(Self { w1, w2, w3 })
+        }
+    }
+
+    impl Module for FeedForward {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
+            swiglu.apply(&self.w2)
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Attention {
+        wqkv: Linear,
+        wo: Linear,
+        dim: usize,
+        kv_size: usize,
+        n_local_heads: usize,
+        head_dim: usize,
+        n_head: usize,
+        kv_cache: Option<(Tensor, Tensor)>,
+    }
+
+    impl Attention {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let n_local_heads = cfg.n_local_heads();
+            let head_dim = cfg.head_dim();
+            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
+            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
+            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
+            Ok(Self {
+                wqkv,
+                wo,
+                dim: cfg.dim,
+                kv_size: n_local_heads * head_dim,
+                n_local_heads,
+                head_dim,
+                n_head: cfg.n_head,
+                kv_cache: None,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let (b_sz, seqlen, _) = xs.dims3()?;
+
+            let qkv = xs.apply(&self.wqkv)?;
+            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
+            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
+            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
+            let q = q
+                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
+                .transpose(1, 2)?
+                .contiguous()?;
+            let k = k
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+            let v = v
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+
+            let (k, v) = match &self.kv_cache {
+                None => (k, v),
+                Some((prev_k, prev_v)) => {
+                    let k = Tensor::cat(&[prev_k, &k], 2)?;
+                    let v = Tensor::cat(&[prev_v, &v], 2)?;
+                    (k, v)
+                }
+            };
+            self.kv_cache = Some((k.clone(), v.clone()));
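+
+            // repeat_interleave expands the key/value heads up to n_head, i.e.
+            // grouped-query attention; with cfg1b_v0_1, n_local_heads equals
+            // n_head so the expansion is a no-op (repeats = 1).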
+            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
+            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
+
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+            let attn_weights = attn_weights.broadcast_add(mask)?;
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            let attn_output = attn_weights.matmul(&v)?;
+            attn_output
+                .transpose(1, 2)?
+                .reshape((b_sz, seqlen, self.dim))?
+                .apply(&self.wo)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.kv_cache = None
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Block {
+        attention: Attention,
+        feed_forward: FeedForward,
+        ffn_norm: RmsNorm,
+        attention_norm: RmsNorm,
+    }
+
+    impl Block {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let attention = Attention::new(cfg, vb.pp("attention"))?;
+            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
+            let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
+            let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
+            Ok(Self {
+                attention,
+                feed_forward,
+                ffn_norm,
+                attention_norm,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let hs = xs.apply(&self.attention_norm)?;
+            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
+            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.attention.clear_kv_cache()
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct Model {
+        tok_embeddings: Embedding,
+        pos_embeddings: Embedding,
+        speaker_cond_pos: Linear,
+        layers: Vec<Block>,
+        norm: RmsNorm,
+        output: Linear,
+        spk_cond_mask: Tensor,
+    }
+
+    impl Model {
+        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
+            let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
+            let speaker_cond_pos = linear_b(
+                cfg.speaker_emb_dim,
+                cfg.dim,
+                false,
+                vb.pp("speaker_cond_pos"),
+            )?;
+            let mut layers = Vec::with_capacity(cfg.n_layer);
+            let vb_l = vb.pp("layers");
+            for layer_idx in 0..cfg.n_layer {
+                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
+                layers.push(layer)
+            }
+            let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
+            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
+            let spk_cond_mask = Tensor::cat(
+                &[
+                    Tensor::ones((1, 1, cfg.dim), DType::F32, vb.device())?,
+                    Tensor::zeros((1, 1, cfg.dim), DType::F32, vb.device())?,
+                ],
+                0,
+            )?;
+            Ok(Self {
+                tok_embeddings,
+                pos_embeddings,
+                speaker_cond_pos,
+                layers,
+                norm,
+                output,
+                spk_cond_mask,
+            })
+        }
+
+        pub fn clear_kv_cache(&mut self) {
+            for layer in self.layers.iter_mut() {
+                layer.clear_kv_cache()
+            }
+        }
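+
+        // The example stacks every input twice along the batch dimension:
+        // spk_cond_mask keeps the speaker conditioning on the first copy and
+        // zeroes it on the second, so the two rows of logits can be blended
+        // for classifier-free guidance (see --guidance-scale in the example).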
+        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let (_b_sz, seqlen) = xs.dims2()?;
+            let mask: Vec<_> = (0..seqlen)
+                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
+            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
+            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
+            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
+            let mut xs = tok_embeddings
+                .broadcast_add(&pos_embeddings)?
+                .broadcast_add(
+                    &spk_emb
+                        .apply(&self.speaker_cond_pos)?
+                        .broadcast_mul(&self.spk_cond_mask)?,
+                )?;
+            for layer in self.layers.iter_mut() {
+                xs = layer.forward(&xs, pos, &mask)?
+            }
+            xs.narrow(1, seqlen - 1, 1)?
+                .apply(&self.norm)?
+                .apply(&self.output)
+        }
+    }
+}
+
+pub mod adapters {
+    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
+    pub struct TiltedEncodec {
+        end_of_audio_token: u32,
+    }
+
+    impl TiltedEncodec {
+        pub fn new(end_of_audio_token: u32) -> Self {
+            Self { end_of_audio_token }
+        }
+
+        pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
+            let mut text_ids = vec![];
+            let mut extracted_audio_ids = vec![];
+            let mut min_audio_ids_len = usize::MAX;
+            for (book_id, tokens) in tokens.iter().enumerate() {
+                let mut audio_ids = vec![];
+                for &t in tokens.iter() {
+                    #[allow(clippy::comparison_chain)]
+                    if t > self.end_of_audio_token {
+                        if book_id == 0 {
+                            text_ids.push(t)
+                        }
+                    } else if t < self.end_of_audio_token {
+                        audio_ids.push(t)
+                    }
+                }
+                min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
+                extracted_audio_ids.push(audio_ids)
+            }
+            for audio_ids in extracted_audio_ids.iter_mut() {
+                audio_ids.truncate(min_audio_ids_len)
+            }
+            (text_ids, extracted_audio_ids)
+        }
+    }
+
+    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
+    pub struct FlattenedInterleavedEncodec2Codebook {
+        end_of_audio_token: u32,
+    }
+
+    impl FlattenedInterleavedEncodec2Codebook {
+        pub fn new(end_of_audio_token: u32) -> Self {
+            Self { end_of_audio_token }
+        }
+
+        pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
+            let mut text_ids = vec![];
+            let mut audio_ids1 = vec![];
+            let mut audio_ids2 = vec![];
+            for &t in tokens.iter() {
+                #[allow(clippy::comparison_chain)]
+                if t < self.end_of_audio_token {
+                    audio_ids1.push(t)
+                } else if t < 2 * self.end_of_audio_token {
+                    audio_ids2.push(t - self.end_of_audio_token)
+                } else {
+                    text_ids.push(t)
+                }
+            }
+            (text_ids, audio_ids1, audio_ids2)
+        }
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a5f03059..871c8107 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -18,6 +18,7 @@ pub mod llama2_c;
 pub mod llama2_c_weights;
 pub mod mamba;
 pub mod marian;
+pub mod metavoice;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;