Files
candle/candle-examples/examples/musicgen/musicgen_model.rs
Laurent Mazare 26cd266e65 Musicgen text embeddings. (#726)
* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
2023-09-03 18:27:48 +01:00

414 lines
14 KiB
Rust

use crate::{encodec_model, t5_model};
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
vocab_size: usize,
max_position_embeddings: usize,
num_hidden_layers: usize,
ffn_dim: usize,
num_attention_heads: usize,
layerdrop: f64,
use_cache: bool,
activation_function: Activation,
hidden_size: usize,
dropout: f64,
attention_dropout: f64,
activation_dropout: f64,
initializer_factor: f64,
scale_embedding: bool,
num_codebooks: usize,
pad_token_id: usize,
bos_token_id: usize,
eos_token_id: Option<usize>,
tie_word_embeddings: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 2048,
max_position_embeddings: 2048,
num_hidden_layers: 24,
ffn_dim: 4096,
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
initializer_factor: 0.02,
scale_embedding: false,
num_codebooks: 4,
pad_token_id: 2048,
bos_token_id: 2048,
eos_token_id: None,
tie_word_embeddings: false,
}
}
}
impl Config {
fn musicgen_small() -> Self {
Self {
vocab_size: 2048,
max_position_embeddings: 2048,
num_hidden_layers: 24,
ffn_dim: 4096,
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
initializer_factor: 0.02,
scale_embedding: false,
num_codebooks: 4,
pad_token_id: 2048,
bos_token_id: 2048,
eos_token_id: None,
tie_word_embeddings: false,
}
}
}
fn get_embedding(num_embeddings: usize, embedding_dim: usize) -> Result<Tensor> {
let half_dim = embedding_dim / 2;
let emb = f64::ln(10000.) / (half_dim - 1) as f64;
let xs: Vec<_> = (0..num_embeddings).map(|v| v as f32).collect();
let xs = Tensor::from_vec(xs, (num_embeddings, 1), &Device::Cpu)?;
let ys: Vec<_> = (0..half_dim)
.map(|v| f64::exp(v as f64 * -emb) as f32)
.collect();
let ys = Tensor::from_vec(ys, (1, half_dim), &Device::Cpu)?;
let shape = (num_embeddings, half_dim);
let emb = (xs.broadcast_as(shape)? * ys.broadcast_as(shape)?)?;
let emb =
Tensor::cat(&[&emb.cos()?, &emb.sin()?], 1)?.reshape((num_embeddings, 2 * half_dim))?;
let emb = if embedding_dim % 2 == 1 {
let zeros = Tensor::zeros((num_embeddings, 1), DType::F32, &Device::Cpu)?;
Tensor::cat(&[&emb, &zeros], 1)?
} else {
emb
};
Ok(emb)
}
#[derive(Debug)]
struct MusicgenSinusoidalPositionalEmbedding {
num_positions: usize,
embedding_dim: usize,
weights: Tensor,
}
impl MusicgenSinusoidalPositionalEmbedding {
fn load(_vb: VarBuilder, cfg: &Config) -> Result<Self> {
let num_positions = cfg.max_position_embeddings;
let embedding_dim = cfg.hidden_size;
let weights = get_embedding(num_positions, embedding_dim)?;
Ok(Self {
num_positions,
embedding_dim,
weights,
})
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
if seq_len > self.weights.dim(0)? {
self.weights = get_embedding(seq_len, self.embedding_dim)?
}
self.weights.narrow(0, 0, seq_len)
}
}
#[derive(Debug)]
struct MusicgenAttention {
scaling: f64,
is_decoder: bool,
num_heads: usize,
head_dim: usize,
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
}
impl MusicgenAttention {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = h / num_heads;
let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
Ok(Self {
scaling: 1. / (head_dim as f64).sqrt(),
is_decoder: true,
num_heads,
head_dim,
k_proj,
v_proj,
q_proj,
out_proj,
})
}
fn forward(
&mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attention_mask: &Tensor,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
let kv_states = kv_states.unwrap_or(xs);
let key_states = self.k_proj.forward(kv_states)?;
let value_states = self.v_proj.forward(kv_states)?;
let tgt = (b_sz, tgt_len, self.num_heads, self.head_dim);
let query_states = query_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let key_states = key_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let src_len = key_states.dim(1)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
.transpose(1, 2)?
.reshape((b_sz, tgt_len, self.num_heads * self.head_dim))?;
let attn_output = self.out_proj.forward(&attn_output)?;
Ok(attn_output)
}
}
#[derive(Debug)]
struct MusicgenDecoderLayer {
self_attn: MusicgenAttention,
self_attn_layer_norm: LayerNorm,
encoder_attn: MusicgenAttention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
activation_fn: Activation,
}
impl MusicgenDecoderLayer {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let self_attn = MusicgenAttention::load(vb.pp("self_attn"), cfg)?;
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
activation_fn: cfg.activation_function,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs.clone();
let xs = self.self_attn_layer_norm.forward(xs)?;
let xs = self.self_attn.forward(&xs, None, attention_mask)?;
let mut xs = (xs + residual)?;
if let Some(encoder_hidden_states) = &encoder_hidden_states {
let residual = xs.clone();
let encoder_attention_mask = attention_mask.clone(); // TODO
xs = self.encoder_attn.forward(
&xs,
Some(encoder_hidden_states),
&encoder_attention_mask,
)?;
xs = (xs + residual)?
}
let residual = xs.clone();
let xs = self.final_layer_norm.forward(&xs)?;
let xs = self.fc1.forward(&xs)?;
let xs = self.activation_fn.forward(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = (xs + residual)?;
Ok(xs)
}
}
#[derive(Debug)]
struct MusicgenDecoder {
embed_tokens: Vec<Embedding>,
embed_positions: MusicgenSinusoidalPositionalEmbedding,
layers: Vec<MusicgenDecoderLayer>,
layer_norm: LayerNorm,
embed_scale: f64,
num_codebooks: usize,
d_model: usize,
}
impl MusicgenDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let embed_scale = if cfg.scale_embedding {
(h as f64).sqrt()
} else {
1.
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
embed_scale,
num_codebooks: cfg.num_codebooks,
d_model: cfg.hidden_size,
})
}
fn prepare_decoder_attention_mask(&self, _b_sz: usize, _seq_len: usize) -> Result<Tensor> {
todo!()
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let dev = input_ids.device();
let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
let b_sz = b_sz_times_codebooks / self.num_codebooks;
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
for (idx, codebook) in self.embed_tokens.iter().enumerate() {
let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
}
let inputs_embeds = inputs_embeds;
let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
let mut xs = inputs_embeds.broadcast_add(&positions)?;
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
xs = decoder_layer.forward(&xs, &attention_mask, None)?;
}
let xs = self.layer_norm.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug)]
pub struct MusicgenForCausalLM {
decoder: MusicgenDecoder,
lm_heads: Vec<Linear>,
num_codebooks: usize,
vocab_size: usize,
}
impl MusicgenForCausalLM {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,
lm_heads,
num_codebooks: cfg.num_codebooks,
vocab_size: cfg.vocab_size,
})
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.dims2()?;
let hidden_states = self.decoder.forward(input_ids)?;
let lm_logits = self
.lm_heads
.iter()
.map(|h| h.forward(&hidden_states))
.collect::<Result<Vec<_>>>()?;
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
b_sz * self.num_codebooks,
seq_len,
self.vocab_size,
))?;
Ok(lm_logits)
}
}
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: crate::t5_model::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
}
#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
musicgen: Config,
t5: crate::t5_model::Config,
encodec: crate::encodec_model::Config,
}
impl GenConfig {
pub fn small() -> Self {
Self {
musicgen: Config::musicgen_small(),
t5: t5_model::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(),
}
}
}
impl MusicgenForConditionalGeneration {
pub fn config(&self) -> &GenConfig {
&self.cfg
}
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self {
text_encoder,
audio_encoder,
decoder,
cfg,
})
}
}