Sketching the musicgen model. (#66)

* Skeleton files for musicgen.

* Add a musicgen model module.

* Sketch the model loading.

* Start adding the forward pass.

* More forward pass.

* Positional embeddings.

* Forward for the decoder layers.

* Add an empty function.

* Fix the musicgen weight names.

* More musicgen modeling.

* Add the T5 loading bits.

* Add the encodec config.

* Add the encodec module hierarchy.

* More Encodec modeling.

* Encodec modeling.

* Encodec modeling.

* Add more to the encodec modeling.

* Load the weights.

* Populate the resnet blocks.

* Also load the conv transpose weights.

* Split musicgen in multiple files.
This commit is contained in:
Laurent Mazare
2023-07-09 19:53:35 +01:00
committed by GitHub
parent c187f347bf
commit ea5dfa69bc
6 changed files with 1497 additions and 1 deletions

View File

@ -1,5 +1,5 @@
#![allow(dead_code)]
// TODO: KV cache.
// TODO: Add an offline mode.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -0,0 +1,482 @@
use crate::nn::{Conv1D, ConvConfig, VarBuilder};
use anyhow::Result;
use candle::Tensor;
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
None,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
target_bandwidths: Vec<f64>,
sampling_rate: usize,
audio_channels: usize,
normalize: bool,
chunk_length_s: Option<usize>,
overlap: Option<usize>,
hidden_size: usize,
num_filters: usize,
num_residual_layers: usize,
upsampling_ratios: Vec<usize>,
norm_type: NormType,
kernel_size: usize,
last_kernel_size: usize,
residual_kernel_size: usize,
dilation_growth_rate: usize,
use_causal_conv: bool,
pad_mode: &'static str,
compress: usize,
num_lstm_layers: usize,
trim_right_ratio: f64,
codebook_size: usize,
codebook_dim: Option<usize>,
use_conv_shortcut: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
pad_mode: "reflect",
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
pub fn musicgen_small() -> Self {
Self {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
pad_mode: "reflect",
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
}
}
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.codebook_size)
}
fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}
fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
inited: Tensor,
cluster_size: Tensor,
embed: Tensor,
embed_avg: Tensor,
}
impl EncodecEuclideanCodebook {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let inited = vb.get(1, &format!("{p}.inited"))?;
let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, &format!("{p}.embed"))?;
let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?;
Ok(Self {
inited,
cluster_size,
embed,
embed_avg,
})
}
}
#[derive(Debug)]
struct EncodecVectorQuantization {
codebook: EncodecEuclideanCodebook,
}
impl EncodecVectorQuantization {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?;
Ok(Self { codebook })
}
}
#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
layers: Vec<EncodecVectorQuantization>,
}
impl EncodecResidualVectorQuantizer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let p = format!("{p}.layers");
let layers = (0..cfg.num_quantizers())
.map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
}
impl EncodecLSTM {
fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let p = format!("{p}.lstm");
let mut layers = vec![];
for i in 0..cfg.num_lstm_layers {
let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?;
let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?;
let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?;
let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?;
layers.push((w_hh, w_ih, b_hh, b_ih))
}
Ok(Self { layers })
}
}
#[derive(Debug)]
struct EncodecConvTranspose1d {
weight_g: Tensor,
weight_v: Tensor,
bias: Tensor,
}
impl EncodecConvTranspose1d {
fn load(
in_c: usize,
out_c: usize,
k: usize,
_stride: usize,
p: &str,
vb: &VarBuilder,
_cfg: &Config,
) -> Result<Self> {
let p = format!("{p}.conv");
let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?;
let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?;
let bias = vb.get(out_c, &format!("{p}.bias"))?;
Ok(Self {
weight_g,
weight_v,
bias,
})
}
}
#[derive(Debug)]
struct EncodecConv1d {
conv: Conv1D,
}
impl EncodecConv1d {
fn load(
in_c: usize,
out_c: usize,
kernel_size: usize,
stride: usize,
p: &str,
vb: &VarBuilder,
cfg: &Config,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => Conv1D::load_weight_norm(
in_c,
out_c,
kernel_size,
ConvConfig { padding: 0, stride },
&format!("{p}.conv"),
vb,
)?,
NormType::None => Conv1D::load(
in_c,
out_c,
kernel_size,
ConvConfig { padding: 0, stride },
&format!("{p}.conv"),
vb,
)?,
};
Ok(Self { conv })
}
}
#[derive(Debug)]
struct EncodecResnetBlock {
block_conv1: EncodecConv1d,
block_conv2: EncodecConv1d,
shortcut: Option<EncodecConv1d>,
}
impl EncodecResnetBlock {
fn load(
dim: usize,
dilations: &[usize],
p: &str,
vb: &VarBuilder,
cfg: &Config,
) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(format!("{p}.block"));
if dilations.len() != 2 {
anyhow::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
let block_conv1 = EncodecConv1d::load(
dim,
h,
cfg.residual_kernel_size,
1,
&layer.next_name(),
vb,
cfg,
)?;
layer.inc();
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?;
Some(conv)
} else {
None
};
Ok(Self {
block_conv1,
block_conv2,
shortcut,
})
}
}
#[derive(Debug)]
struct Layer {
prefix: String,
cnt: usize,
}
impl Layer {
fn new(prefix: String) -> Self {
Self { prefix, cnt: 0 }
}
fn inc(&mut self) {
self.cnt += 1;
}
fn next_name(&mut self) -> String {
let name = format!("{}.{}", self.prefix, self.cnt);
self.cnt += 1;
name
}
}
#[derive(Debug)]
struct EncodecEncoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl EncodecEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(format!("{p}.layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
&layer.next_name(),
vb,
cfg,
)?;
let mut sampling_layers = vec![];
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
&layer.next_name(),
vb,
cfg,
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
let conv1d = EncodecConv1d::load(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
&layer.next_name(),
vb,
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
&layer.next_name(),
vb,
cfg,
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
}
#[derive(Debug)]
struct EncodecDecoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl EncodecDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(format!("{p}.layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
&layer.next_name(),
vb,
cfg,
)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
let conv1d = EncodecConvTranspose1d::load(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
&layer.next_name(),
vb,
cfg,
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
&layer.next_name(),
vb,
cfg,
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
&layer.next_name(),
vb,
cfg,
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
}
#[derive(Debug)]
pub struct EncodecModel {
encoder: EncodecEncoder,
decoder: EncodecDecoder,
quantizer: EncodecResidualVectorQuantizer,
}
impl EncodecModel {
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
}

View File

@ -0,0 +1,59 @@
#![allow(dead_code)]
// https://huggingface.co/facebook/musicgen-small/tree/main
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/modeling_musicgen.py
// TODO: Add an offline mode.
// TODO: Add a KV cache.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod encodec_model;
mod musicgen_model;
mod nn;
mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;
use anyhow::{Error as E, Result};
use candle::{DType, Device};
use clap::Parser;
const DTYPE: DType = DType::F32;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: String,
/// The tokenizer config.
#[arg(long)]
tokenizer: String,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
} else {
Device::new_cuda(0)?
};
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
let _model = MusicgenForConditionalGeneration::load(&vb, config)?;
Ok(())
}

View File

@ -0,0 +1,412 @@
use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
use crate::{encodec_model, t5_model};
use anyhow::Result;
use candle::{DType, Device, Tensor, D};
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
vocab_size: usize,
max_position_embeddings: usize,
num_hidden_layers: usize,
ffn_dim: usize,
num_attention_heads: usize,
layerdrop: f64,
use_cache: bool,
activation_function: HiddenAct,
hidden_size: usize,
dropout: f64,
attention_dropout: f64,
activation_dropout: f64,
initializer_factor: f64,
scale_embedding: bool,
num_codebooks: usize,
pad_token_id: usize,
bos_token_id: usize,
eos_token_id: Option<usize>,
tie_word_embeddings: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 2048,
max_position_embeddings: 2048,
num_hidden_layers: 24,
ffn_dim: 4096,
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
initializer_factor: 0.02,
scale_embedding: false,
num_codebooks: 4,
pad_token_id: 2048,
bos_token_id: 2048,
eos_token_id: None,
tie_word_embeddings: false,
}
}
}
impl Config {
fn musicgen_small() -> Self {
Self {
vocab_size: 2048,
max_position_embeddings: 2048,
num_hidden_layers: 24,
ffn_dim: 4096,
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
activation_dropout: 0.0,
initializer_factor: 0.02,
scale_embedding: false,
num_codebooks: 4,
pad_token_id: 2048,
bos_token_id: 2048,
eos_token_id: None,
tie_word_embeddings: false,
}
}
}
fn get_embedding(num_embeddings: usize, embedding_dim: usize) -> Result<Tensor> {
let half_dim = embedding_dim / 2;
let emb = f64::ln(10000.) / (half_dim - 1) as f64;
let xs: Vec<_> = (0..num_embeddings).map(|v| v as f32).collect();
let xs = Tensor::from_vec(xs, (num_embeddings, 1), &Device::Cpu)?;
let ys: Vec<_> = (0..half_dim)
.map(|v| f64::exp(v as f64 * -emb) as f32)
.collect();
let ys = Tensor::from_vec(ys, (1, half_dim), &Device::Cpu)?;
let shape = (num_embeddings, half_dim);
let emb = (xs.broadcast_as(shape)? * ys.broadcast_as(shape)?)?;
let emb =
Tensor::cat(&[&emb.cos()?, &emb.sin()?], 1)?.reshape((num_embeddings, 2 * half_dim))?;
let emb = if embedding_dim % 2 == 1 {
let zeros = Tensor::zeros((num_embeddings, 1), DType::F32, &Device::Cpu)?;
Tensor::cat(&[&emb, &zeros], 1)?
} else {
emb
};
Ok(emb)
}
#[derive(Debug)]
struct MusicgenSinusoidalPositionalEmbedding {
num_positions: usize,
embedding_dim: usize,
weights: Tensor,
}
impl MusicgenSinusoidalPositionalEmbedding {
fn load(_vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let num_positions = cfg.max_position_embeddings;
let embedding_dim = cfg.hidden_size;
let weights = get_embedding(num_positions, embedding_dim)?;
Ok(Self {
num_positions,
embedding_dim,
weights,
})
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
if seq_len > self.weights.dim(0)? {
self.weights = get_embedding(seq_len, self.embedding_dim)?
}
Ok(self.weights.narrow(0, 0, seq_len)?)
}
}
#[derive(Debug)]
struct MusicgenAttention {
scaling: f64,
is_decoder: bool,
num_heads: usize,
head_dim: usize,
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
}
impl MusicgenAttention {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = h / num_heads;
let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?;
let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?;
let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?;
let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?;
Ok(Self {
scaling: 1. / (head_dim as f64).sqrt(),
is_decoder: true,
num_heads,
head_dim,
k_proj,
v_proj,
q_proj,
out_proj,
})
}
fn forward(
&mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attention_mask: &Tensor,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.shape().r3()?;
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
let kv_states = kv_states.unwrap_or(xs);
let key_states = self.k_proj.forward(kv_states)?;
let value_states = self.v_proj.forward(kv_states)?;
let tgt = (b_sz, tgt_len, self.num_heads, self.head_dim);
let query_states = query_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let key_states = key_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
let src_len = key_states.dim(1)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
let attn_weights = attn_weights.softmax(D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
.transpose(1, 2)?
.reshape((b_sz, tgt_len, self.num_heads * self.head_dim))?;
let attn_output = self.out_proj.forward(&attn_output)?;
Ok(attn_output)
}
}
#[derive(Debug)]
struct MusicgenDecoderLayer {
self_attn: MusicgenAttention,
self_attn_layer_norm: LayerNorm,
encoder_attn: MusicgenAttention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
activation_fn: HiddenAct,
}
impl MusicgenDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
let self_attn_layer_norm =
LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
let encoder_attn_layer_norm =
LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
Ok(Self {
self_attn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
activation_fn: cfg.activation_function,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs.clone();
let xs = self.self_attn_layer_norm.forward(xs)?;
let xs = self.self_attn.forward(&xs, None, attention_mask)?;
let mut xs = (xs + residual)?;
if let Some(encoder_hidden_states) = &encoder_hidden_states {
let residual = xs.clone();
let encoder_attention_mask = attention_mask.clone(); // TODO
xs = self.encoder_attn.forward(
&xs,
Some(encoder_hidden_states),
&encoder_attention_mask,
)?;
xs = (xs + residual)?
}
let residual = xs.clone();
let xs = self.final_layer_norm.forward(&xs)?;
let xs = self.fc1.forward(&xs)?;
let xs = self.activation_fn.forward(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = (xs + residual)?;
Ok(xs)
}
}
#[derive(Debug)]
struct MusicgenDecoder {
embed_tokens: Vec<Embedding>,
embed_positions: MusicgenSinusoidalPositionalEmbedding,
layers: Vec<MusicgenDecoderLayer>,
layer_norm: LayerNorm,
embed_scale: f64,
num_codebooks: usize,
d_model: usize,
}
impl MusicgenDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let embed_scale = if cfg.scale_embedding {
(h as f64).sqrt()
} else {
1.
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| Embedding::load(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
embed_scale,
num_codebooks: cfg.num_codebooks,
d_model: cfg.hidden_size,
})
}
fn prepare_decoder_attention_mask(&self, _b_sz: usize, _seq_len: usize) -> Result<Tensor> {
todo!()
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let dev = input_ids.device();
let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
let b_sz = b_sz_times_codebooks / self.num_codebooks;
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, &dev)?;
for (idx, codebook) in self.embed_tokens.iter().enumerate() {
let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
}
let inputs_embeds = inputs_embeds;
let positions = self.embed_positions.forward(&input)?.to_device(&dev)?;
let mut xs = inputs_embeds.broadcast_add(&positions)?;
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
xs = decoder_layer.forward(&xs, &attention_mask, None)?;
}
let xs = self.layer_norm.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug)]
pub struct MusicgenForCausalLM {
decoder: MusicgenDecoder,
lm_heads: Vec<Linear>,
num_codebooks: usize,
vocab_size: usize,
}
impl MusicgenForCausalLM {
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,
lm_heads,
num_codebooks: cfg.num_codebooks,
vocab_size: cfg.vocab_size,
})
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.shape().r2()?;
let hidden_states = self.decoder.forward(input_ids)?;
let lm_logits = self
.lm_heads
.iter()
.map(|h| Ok(h.forward(&hidden_states)?))
.collect::<Result<Vec<_>>>()?;
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
b_sz * self.num_codebooks,
seq_len,
self.vocab_size,
))?;
Ok(lm_logits)
}
}
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
text_encoder: crate::t5_model::T5EncoderModel,
audio_encoder: crate::encodec_model::EncodecModel,
decoder: MusicgenForCausalLM,
cfg: GenConfig,
}
#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
musicgen: Config,
t5: crate::t5_model::Config,
encodec: crate::encodec_model::Config,
}
impl GenConfig {
pub fn small() -> Self {
Self {
musicgen: Config::musicgen_small(),
t5: t5_model::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(),
}
}
}
impl MusicgenForConditionalGeneration {
pub fn config(&self) -> &GenConfig {
&self.cfg
}
pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?;
let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?;
Ok(Self {
text_encoder,
audio_encoder,
decoder,
cfg,
})
}
}

View File

@ -0,0 +1,255 @@
#![allow(dead_code)]
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use std::collections::HashMap;
const MAX_SEQ_LEN: usize = 5000;
pub struct VarBuilder<'a> {
safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
dtype: DType,
device: Device,
}
impl<'a> VarBuilder<'a> {
pub fn from_safetensors(
safetensors: Vec<SafeTensors<'a>>,
dtype: DType,
device: &Device,
) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
Self {
safetensors: Some((routing, safetensors)),
device: device.clone(),
dtype,
}
}
pub fn zeros(dtype: DType, device: &Device) -> Self {
Self {
safetensors: None,
device: device.clone(),
dtype,
}
}
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
let s: Shape = s.into();
match &self.safetensors {
None => Tensor::zeros(s, self.dtype, &self.device),
Some((routing, safetensors)) => {
// Unwrap or 0 just to let the proper error flow.
let index = routing.get(tensor_name).unwrap_or(&0);
let tensor = safetensors[*index]
.tensor(tensor_name, &self.device)?
.to_dtype(self.dtype)?;
if *tensor.shape() != s {
let msg = format!("shape mismatch for {tensor_name}");
Err(candle::Error::UnexpectedShape {
msg,
expected: s,
got: tensor.shape().clone(),
})?
}
Ok(tensor)
}
}
}
}
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
let bias = if bias {
Some(vb.get(size2, &format!("{p}.bias"))?)
} else {
None
};
Ok(Self { weight, bias })
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let (bsize, _, _) = x.shape().r3()?;
let w = self.weight.broadcast_left(bsize)?.t()?;
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let (weight, bias) = match (
vb.get(size, &format!("{p}.weight")),
vb.get(size, &format!("{p}.bias")),
) {
(Ok(weight), Ok(bias)) => (weight, bias),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(weight), Ok(bias)) = (
vb.get(size, &format!("{p}.gamma")),
vb.get(size, &format!("{p}.beta")),
) {
(weight, bias)
} else {
return Err(err.into());
}
}
};
Ok(Self { weight, bias, eps })
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let dtype = x.dtype();
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(DType::F32)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
Ok(x)
}
}
#[derive(Debug)]
pub struct Dropout {
pr: f64,
}
impl Dropout {
pub fn new(pr: f64) -> Self {
Self { pr }
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
#[derive(Debug)]
pub struct Embedding {
embeddings: Tensor,
hidden_size: usize,
}
impl Embedding {
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
Self {
embeddings,
hidden_size,
}
}
pub fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
Ok(Self::new(embeddings, hidden_size))
}
pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
let values = Tensor::embedding(&indexes, &self.embeddings)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvConfig {
pub padding: usize,
pub stride: usize,
}
#[derive(Debug)]
pub struct Conv1D {
weight: Tensor,
bias: Option<Tensor>,
config: ConvConfig,
}
impl Conv1D {
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn load_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: ConvConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Self> {
let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, &format!("{p}.bias"))?;
Ok(Self {
weight,
bias: Some(bias),
config,
})
}
pub fn load(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: ConvConfig,
p: &str,
vb: &VarBuilder,
) -> Result<Self> {
let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
let bias = vb.get(out_c, &format!("{p}.bias"))?;
Ok(Self {
weight,
bias: Some(bias),
config,
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HiddenAct {
Gelu,
Relu,
}
impl HiddenAct {
pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
}
}
}

View File

@ -0,0 +1,288 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::Tensor;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
vocab_size: usize,
d_model: usize,
d_kv: usize,
d_ff: usize,
num_layers: usize,
num_decoder_layers: Option<usize>,
num_heads: usize,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
feed_forward_proj: HiddenAct,
is_decoder: bool,
is_encoder_decoder: bool,
use_cache: bool,
pad_token_id: usize,
eos_token_id: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 32128,
d_model: 512,
d_kv: 64,
d_ff: 2048,
num_layers: 6,
num_decoder_layers: None,
num_heads: 8,
relative_attention_num_buckets: 32,
relative_attention_max_distance: 128,
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: HiddenAct::Relu,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
pub fn musicgen_small() -> Self {
Self {
d_ff: 3072,
d_kv: 64,
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: HiddenAct::Relu,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
layer_norm_epsilon: 1e-6,
num_decoder_layers: Some(12),
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
vocab_size: 32128,
}
}
}
#[derive(Debug)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
}
impl T5LayerNorm {
fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(h, &format!("{p}.weight"))?;
Ok(Self {
weight,
variance_epsilon: eps,
})
}
}
#[derive(Debug)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
dropout: Dropout,
act: HiddenAct,
}
impl T5DenseActDense {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
wi,
wo,
dropout,
act: HiddenAct::Relu,
})
}
}
#[derive(Debug)]
struct T5LayerFF {
dense_relu_dense: T5DenseActDense,
layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5LayerFF {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
// is_gated_act is not supported.
let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?;
let layer_norm = T5LayerNorm::load(
cfg.d_model,
cfg.layer_norm_epsilon,
&format!("{p}.layer_norm"),
vb,
)?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
dense_relu_dense,
layer_norm,
dropout,
})
}
}
#[derive(Debug)]
struct T5Attention {
q: Linear,
k: Linear,
v: Linear,
o: Linear,
relative_attention_bias: Option<Embedding>,
}
impl T5Attention {
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
let relative_attention_bias = if h {
let emb = Embedding::load(
cfg.relative_attention_num_buckets,
cfg.num_heads,
&format!("{p}.relative_attention_bias"),
vb,
)?;
Some(emb)
} else {
None
};
Ok(Self {
q,
k,
v,
o,
relative_attention_bias,
})
}
}
#[derive(Debug)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5LayerSelfAttention {
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?;
let layer_norm = T5LayerNorm::load(
cfg.d_model,
cfg.layer_norm_epsilon,
&format!("{p}.layer_norm"),
vb,
)?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
self_attention,
layer_norm,
dropout,
})
}
}
#[derive(Debug)]
struct T5LayerCrossAttention {}
impl T5LayerCrossAttention {
fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result<Self> {
todo!()
}
}
#[derive(Debug)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
ff: T5LayerFF,
}
impl T5Block {
fn load(
has_relative_attention_bias: bool,
p: &str,
vb: &VarBuilder,
cfg: &Config,
) -> Result<Self> {
let p = &format!("{p}.layer");
let self_attn =
T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?;
let cross_attn = if cfg.is_decoder {
Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?)
} else {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?;
Ok(Self {
self_attn,
cross_attn,
ff,
})
}
}
#[derive(Debug)]
struct T5Stack {
// TODO: Add embed_tokens if needed (shared embedding layer).
block: Vec<T5Block>,
final_layer_norm: T5LayerNorm,
dropout: Dropout,
}
impl T5Stack {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg))
.collect::<Result<Vec<_>>>()?;
let final_layer_norm = T5LayerNorm::load(
cfg.d_model,
cfg.layer_norm_epsilon,
&format!("{p}.final_layer_norm"),
vb,
)?;
let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
block,
final_layer_norm,
dropout,
})
}
}
#[derive(Debug)]
pub struct T5EncoderModel {
shared: Embedding,
encoder: T5Stack,
}
impl T5EncoderModel {
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
Ok(Self { shared, encoder })
}
}