Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Sketching the musicgen model. (#66)
* Skeleton files for musicgen.
* Add a musicgen model module.
* Sketch the model loading.
* Start adding the forward pass.
* More forward pass.
* Positional embeddings.
* Forward for the decoder layers.
* Add an empty function.
* Fix the musicgen weight names.
* More musicgen modeling.
* Add the T5 loading bits.
* Add the encodec config.
* Add the encodec module hierarchy.
* More Encodec modeling.
* Encodec modeling.
* Encodec modeling.
* Add more to the encodec modeling.
* Load the weights.
* Populate the resnet blocks.
* Also load the conv transpose weights.
* Split musicgen in multiple files.
@@ -1,5 +1,5 @@
#![allow(dead_code)]
// TODO: KV cache.
// TODO: Add an offline mode.

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

candle-examples/examples/musicgen/encodec_model.rs (new file, 482 lines)
@@ -0,0 +1,482 @@
use crate::nn::{Conv1D, ConvConfig, VarBuilder};
use anyhow::Result;
use candle::Tensor;

// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py

#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
    None,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    target_bandwidths: Vec<f64>,
    sampling_rate: usize,
    audio_channels: usize,
    normalize: bool,
    chunk_length_s: Option<usize>,
    overlap: Option<usize>,
    hidden_size: usize,
    num_filters: usize,
    num_residual_layers: usize,
    upsampling_ratios: Vec<usize>,
    norm_type: NormType,
    kernel_size: usize,
    last_kernel_size: usize,
    residual_kernel_size: usize,
    dilation_growth_rate: usize,
    use_causal_conv: bool,
    pad_mode: &'static str,
    compress: usize,
    num_lstm_layers: usize,
    trim_right_ratio: f64,
    codebook_size: usize,
    codebook_dim: Option<usize>,
    use_conv_shortcut: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
            sampling_rate: 24_000,
            audio_channels: 1,
            normalize: false,
            chunk_length_s: None,
            overlap: None,
            hidden_size: 128,
            num_filters: 32,
            num_residual_layers: 1,
            upsampling_ratios: vec![8, 5, 4, 2],
            norm_type: NormType::WeightNorm,
            kernel_size: 7,
            last_kernel_size: 7,
            residual_kernel_size: 3,
            dilation_growth_rate: 2,
            use_causal_conv: true,
            pad_mode: "reflect",
            compress: 2,
            num_lstm_layers: 2,
            trim_right_ratio: 1.0,
            codebook_size: 1024,
            codebook_dim: None,
            use_conv_shortcut: true,
        }
    }
}

impl Config {
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
    pub fn musicgen_small() -> Self {
        Self {
            audio_channels: 1,
            chunk_length_s: None,
            codebook_dim: Some(128),
            codebook_size: 2048,
            compress: 2,
            dilation_growth_rate: 2,
            hidden_size: 128,
            kernel_size: 7,
            last_kernel_size: 7,
            norm_type: NormType::WeightNorm,
            normalize: false,
            num_filters: 64,
            num_lstm_layers: 2,
            num_residual_layers: 1,
            overlap: None,
            pad_mode: "reflect",
            residual_kernel_size: 3,
            sampling_rate: 32_000,
            target_bandwidths: vec![2.2],
            trim_right_ratio: 1.0,
            upsampling_ratios: vec![8, 5, 4, 4],
            use_causal_conv: false,
            use_conv_shortcut: false,
        }
    }

    fn codebook_dim(&self) -> usize {
        self.codebook_dim.unwrap_or(self.codebook_size)
    }

    fn frame_rate(&self) -> usize {
        let hop_length: usize = self.upsampling_ratios.iter().product();
        (self.sampling_rate + hop_length - 1) / hop_length
    }

    fn num_quantizers(&self) -> usize {
        let num = 1000f64
            * self
                .target_bandwidths
                .last()
                .expect("empty target_bandwidths");
        (num as usize) / (self.frame_rate() * 10)
    }
}
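
// Worked example, added for illustration (not part of the commit): for
// `musicgen_small`, hop_length = 8 * 5 * 4 * 4 = 640, so frame_rate() =
// ceil(32_000 / 640) = 50, and num_quantizers() =
// (1000. * 2.2) as usize / (50 * 10) = 2200 / 500 = 4 codebooks.
#[test]
fn musicgen_small_derived_values() {
    let cfg = Config::musicgen_small();
    assert_eq!(cfg.frame_rate(), 50);
    assert_eq!(cfg.num_quantizers(), 4);
}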

// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    embed: Tensor,
    embed_avg: Tensor,
}

impl EncodecEuclideanCodebook {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let inited = vb.get(1, &format!("{p}.inited"))?;
        let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?;
        let e_shape = (cfg.codebook_size, cfg.codebook_dim());
        let embed = vb.get(e_shape, &format!("{p}.embed"))?;
        let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?;
        Ok(Self {
            inited,
            cluster_size,
            embed,
            embed_avg,
        })
    }
}

#[derive(Debug)]
struct EncodecVectorQuantization {
    codebook: EncodecEuclideanCodebook,
}

impl EncodecVectorQuantization {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?;
        Ok(Self { codebook })
    }
}

#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
    layers: Vec<EncodecVectorQuantization>,
}

impl EncodecResidualVectorQuantizer {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let p = format!("{p}.layers");
        let layers = (0..cfg.num_quantizers())
            .map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }
}

// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
}

impl EncodecLSTM {
    fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let p = format!("{p}.lstm");
        let mut layers = vec![];
        for i in 0..cfg.num_lstm_layers {
            let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?;
            let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?;
            let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?;
            let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?;
            layers.push((w_hh, w_ih, b_hh, b_ih))
        }
        Ok(Self { layers })
    }
}
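
// Added note (not part of the commit): the 4 * dim leading dimension matches
// PyTorch's LSTM weight packing, where the input, forget, cell and output
// gate parameters are stacked into a single (4 * hidden, hidden) matrix and
// a single (4 * hidden,) bias per layer.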

#[derive(Debug)]
struct EncodecConvTranspose1d {
    weight_g: Tensor,
    weight_v: Tensor,
    bias: Tensor,
}

impl EncodecConvTranspose1d {
    fn load(
        in_c: usize,
        out_c: usize,
        k: usize,
        _stride: usize,
        p: &str,
        vb: &VarBuilder,
        _cfg: &Config,
    ) -> Result<Self> {
        let p = format!("{p}.conv");
        let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?;
        let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?;
        let bias = vb.get(out_c, &format!("{p}.bias"))?;
        Ok(Self {
            weight_g,
            weight_v,
            bias,
        })
    }
}

#[derive(Debug)]
struct EncodecConv1d {
    conv: Conv1D,
}

impl EncodecConv1d {
    fn load(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        stride: usize,
        p: &str,
        vb: &VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let conv = match cfg.norm_type {
            NormType::WeightNorm => Conv1D::load_weight_norm(
                in_c,
                out_c,
                kernel_size,
                ConvConfig { padding: 0, stride },
                &format!("{p}.conv"),
                vb,
            )?,
            NormType::None => Conv1D::load(
                in_c,
                out_c,
                kernel_size,
                ConvConfig { padding: 0, stride },
                &format!("{p}.conv"),
                vb,
            )?,
        };
        Ok(Self { conv })
    }
}

#[derive(Debug)]
struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    shortcut: Option<EncodecConv1d>,
}

impl EncodecResnetBlock {
    fn load(
        dim: usize,
        dilations: &[usize],
        p: &str,
        vb: &VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let h = dim / cfg.compress;
        let mut layer = Layer::new(format!("{p}.block"));
        if dilations.len() != 2 {
            anyhow::bail!("expected dilations of size 2")
        }
        // TODO: Apply dilations!
        layer.inc();
        let block_conv1 = EncodecConv1d::load(
            dim,
            h,
            cfg.residual_kernel_size,
            1,
            &layer.next_name(),
            vb,
            cfg,
        )?;
        layer.inc();
        let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?;
        let shortcut = if cfg.use_conv_shortcut {
            let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?;
            Some(conv)
        } else {
            None
        };
        Ok(Self {
            block_conv1,
            block_conv2,
            shortcut,
        })
    }
}

#[derive(Debug)]
struct Layer {
    prefix: String,
    cnt: usize,
}

impl Layer {
    fn new(prefix: String) -> Self {
        Self { prefix, cnt: 0 }
    }

    fn inc(&mut self) {
        self.cnt += 1;
    }

    fn next_name(&mut self) -> String {
        let name = format!("{}.{}", self.prefix, self.cnt);
        self.cnt += 1;
        name
    }
}

#[derive(Debug)]
struct EncodecEncoder {
    init_conv: EncodecConv1d,
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}

impl EncodecEncoder {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let mut layer = Layer::new(format!("{p}.layers"));
        let init_conv = EncodecConv1d::load(
            cfg.audio_channels,
            cfg.num_filters,
            cfg.kernel_size,
            1,
            &layer.next_name(),
            vb,
            cfg,
        )?;
        let mut sampling_layers = vec![];
        let mut scaling = 1;
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let current_scale = scaling * cfg.num_filters;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    &layer.next_name(),
                    vb,
                    cfg,
                )?;
                resnets.push(resnet)
            }
            layer.inc(); // ELU
            let conv1d = EncodecConv1d::load(
                current_scale,
                current_scale * 2,
                ratio * 2,
                ratio,
                &layer.next_name(),
                vb,
                cfg,
            )?;
            sampling_layers.push((resnets, conv1d));
            scaling *= 2;
        }
        let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters * scaling,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            &layer.next_name(),
            vb,
            cfg,
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_conv,
            final_lstm,
        })
    }
}

#[derive(Debug)]
struct EncodecDecoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}

impl EncodecDecoder {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let mut layer = Layer::new(format!("{p}.layers"));
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        let init_conv = EncodecConv1d::load(
            cfg.hidden_size,
            cfg.num_filters * scaling,
            cfg.last_kernel_size,
            1,
            &layer.next_name(),
            vb,
            cfg,
        )?;
        let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
        let mut sampling_layers = vec![];
        for &ratio in cfg.upsampling_ratios.iter() {
            let current_scale = scaling * cfg.num_filters;
            layer.inc(); // ELU
            let conv1d = EncodecConvTranspose1d::load(
                current_scale,
                current_scale / 2,
                ratio * 2,
                ratio,
                &layer.next_name(),
                vb,
                cfg,
            )?;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale / 2,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    &layer.next_name(),
                    vb,
                    cfg,
                )?;
                resnets.push(resnet)
            }
            sampling_layers.push((conv1d, resnets));
            scaling /= 2;
        }
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            &layer.next_name(),
            vb,
            cfg,
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }
}

#[derive(Debug)]
pub struct EncodecModel {
    encoder: EncodecEncoder,
    decoder: EncodecDecoder,
    quantizer: EncodecResidualVectorQuantizer,
}

impl EncodecModel {
    pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
        let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
        let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?;
        Ok(Self {
            encoder,
            decoder,
            quantizer,
        })
    }
}
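
For orientation, the `p` arguments compose into full safetensors names as the loaders nest. A minimal sketch, assuming the HF musicgen-small checkpoint layout (not something the commit asserts):

// Hypothetical name composition: the prefix for the first encoder
// convolution composes as below, and weight-norm loading then requests
// `.weight_g`, `.weight_v` and `.bias` under it.
let p = "audio_encoder"; // prefix passed to EncodecModel::load
let encoder = format!("{p}.encoder"); // EncodecEncoder::load
let first_layer = format!("{encoder}.layers.0"); // Layer::next_name
let conv = format!("{first_layer}.conv"); // EncodecConv1d appends ".conv"
assert_eq!(conv, "audio_encoder.encoder.layers.0.conv");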

candle-examples/examples/musicgen/main.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
#![allow(dead_code)]
// https://huggingface.co/facebook/musicgen-small/tree/main
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/modeling_musicgen.py
// TODO: Add an offline mode.
// TODO: Add a KV cache.

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod encodec_model;
mod musicgen_model;
mod nn;
mod t5_model;

use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;

use anyhow::{Error as E, Result};
use candle::{DType, Device};
use clap::Parser;

const DTYPE: DType = DType::F32;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensors format.
    #[arg(long)]
    model: String,

    /// The tokenizer config.
    #[arg(long)]
    tokenizer: String,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;

    let args = Args::parse();
    let device = if args.cpu {
        Device::Cpu
    } else {
        Device::new_cuda(0)?
    };

    let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
    let _tokenizer = tokenizer.with_padding(None).with_truncation(None);

    let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
    let model = model.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
    let config = GenConfig::small();
    let _model = MusicgenForConditionalGeneration::load(&vb, config)?;
    Ok(())
}
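
`main` currently stops after loading the weights. A hedged sketch of the next step, encoding a text prompt into ids for the T5 encoder: the `encode` and `get_ids` calls are the standard `tokenizers` crate API, while the tensor plumbing and the prompt itself are assumptions, not part of the commit.

// Hypothetical continuation of main().
use candle::Tensor;

let prompt = "90s rock song with electric guitar and heavy drums";
let tokens = tokenizer
    .encode(prompt, true) // add_special_tokens = true
    .map_err(E::msg)?
    .get_ids()
    .to_vec();
// Shape (1, seq_len) so that a batch dimension is present.
let _input_ids = Tensor::new(tokens.as_slice(), &device)?.reshape((1, tokens.len()))?;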

candle-examples/examples/musicgen/musicgen_model.rs (new file, 412 lines)
@@ -0,0 +1,412 @@
use crate::nn::{Embedding, HiddenAct, LayerNorm, Linear, VarBuilder};
use crate::{encodec_model, t5_model};
use anyhow::Result;
use candle::{DType, Device, Tensor, D};

// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    vocab_size: usize,
    max_position_embeddings: usize,
    num_hidden_layers: usize,
    ffn_dim: usize,
    num_attention_heads: usize,
    layerdrop: f64,
    use_cache: bool,
    activation_function: HiddenAct,
    hidden_size: usize,
    dropout: f64,
    attention_dropout: f64,
    activation_dropout: f64,
    initializer_factor: f64,
    scale_embedding: bool,
    num_codebooks: usize,
    pad_token_id: usize,
    bos_token_id: usize,
    eos_token_id: Option<usize>,
    tie_word_embeddings: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 2048,
            max_position_embeddings: 2048,
            num_hidden_layers: 24,
            ffn_dim: 4096,
            num_attention_heads: 16,
            layerdrop: 0.0,
            use_cache: true,
            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
            hidden_size: 1024,
            dropout: 0.1,
            attention_dropout: 0.0,
            activation_dropout: 0.0,
            initializer_factor: 0.02,
            scale_embedding: false,
            num_codebooks: 4,
            pad_token_id: 2048,
            bos_token_id: 2048,
            eos_token_id: None,
            tie_word_embeddings: false,
        }
    }
}

impl Config {
    fn musicgen_small() -> Self {
        Self {
            vocab_size: 2048,
            max_position_embeddings: 2048,
            num_hidden_layers: 24,
            ffn_dim: 4096,
            num_attention_heads: 16,
            layerdrop: 0.0,
            use_cache: true,
            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
            hidden_size: 1024,
            dropout: 0.1,
            attention_dropout: 0.0,
            activation_dropout: 0.0,
            initializer_factor: 0.02,
            scale_embedding: false,
            num_codebooks: 4,
            pad_token_id: 2048,
            bos_token_id: 2048,
            eos_token_id: None,
            tie_word_embeddings: false,
        }
    }
}

fn get_embedding(num_embeddings: usize, embedding_dim: usize) -> Result<Tensor> {
    let half_dim = embedding_dim / 2;
    let emb = f64::ln(10000.) / (half_dim - 1) as f64;
    let xs: Vec<_> = (0..num_embeddings).map(|v| v as f32).collect();
    let xs = Tensor::from_vec(xs, (num_embeddings, 1), &Device::Cpu)?;
    let ys: Vec<_> = (0..half_dim)
        .map(|v| f64::exp(v as f64 * -emb) as f32)
        .collect();
    let ys = Tensor::from_vec(ys, (1, half_dim), &Device::Cpu)?;
    let shape = (num_embeddings, half_dim);
    let emb = (xs.broadcast_as(shape)? * ys.broadcast_as(shape)?)?;
    let emb =
        Tensor::cat(&[&emb.cos()?, &emb.sin()?], 1)?.reshape((num_embeddings, 2 * half_dim))?;
    let emb = if embedding_dim % 2 == 1 {
        let zeros = Tensor::zeros((num_embeddings, 1), DType::F32, &Device::Cpu)?;
        Tensor::cat(&[&emb, &zeros], 1)?
    } else {
        emb
    };
    Ok(emb)
}
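
// Added note (not part of the commit): this builds a fairseq-style sinusoidal
// position table. For position p and channel j < half_dim, with
// f(j) = exp(-j * ln(10000) / (half_dim - 1)):
//   emb[p, j]            = cos(p * f(j))
//   emb[p, half_dim + j] = sin(p * f(j))
// i.e. all cosines come first, then all sines, rather than interleaving.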

#[derive(Debug)]
struct MusicgenSinusoidalPositionalEmbedding {
    num_positions: usize,
    embedding_dim: usize,
    weights: Tensor,
}

impl MusicgenSinusoidalPositionalEmbedding {
    fn load(_vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let num_positions = cfg.max_position_embeddings;
        let embedding_dim = cfg.hidden_size;
        let weights = get_embedding(num_positions, embedding_dim)?;
        Ok(Self {
            num_positions,
            embedding_dim,
            weights,
        })
    }

    fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
        if seq_len > self.weights.dim(0)? {
            self.weights = get_embedding(seq_len, self.embedding_dim)?
        }
        Ok(self.weights.narrow(0, 0, seq_len)?)
    }
}

#[derive(Debug)]
struct MusicgenAttention {
    scaling: f64,
    is_decoder: bool,
    num_heads: usize,
    head_dim: usize,
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
}

impl MusicgenAttention {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let h = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let head_dim = h / num_heads;
        let k_proj = Linear::load(h, h, false, &format!("{p}.k_proj"), vb)?;
        let v_proj = Linear::load(h, h, false, &format!("{p}.v_proj"), vb)?;
        let q_proj = Linear::load(h, h, false, &format!("{p}.q_proj"), vb)?;
        let out_proj = Linear::load(h, h, false, &format!("{p}.out_proj"), vb)?;
        Ok(Self {
            scaling: 1. / (head_dim as f64).sqrt(),
            is_decoder: true,
            num_heads,
            head_dim,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        kv_states: Option<&Tensor>,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (b_sz, tgt_len, _) = xs.shape().r3()?;
        let query_states = (self.q_proj.forward(xs)? * self.scaling)?;

        let kv_states = kv_states.unwrap_or(xs);
        let key_states = self.k_proj.forward(kv_states)?;
        let value_states = self.v_proj.forward(kv_states)?;

        let tgt = (b_sz, tgt_len, self.num_heads, self.head_dim);
        let query_states = query_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
        let key_states = key_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;
        let value_states = value_states.reshape(tgt)?.transpose(1, 2)?.contiguous()?;

        let src_len = key_states.dim(1)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
        let attn_weights = attn_weights
            .reshape((b_sz, self.num_heads, tgt_len, src_len))?
            .broadcast_add(attention_mask)?;
        let attn_weights = attn_weights.softmax(D::Minus1)?;
        // TODO: layer_head_mask?
        let attn_output = attn_weights
            .matmul(&value_states)?
            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, tgt_len, self.num_heads * self.head_dim))?;
        let attn_output = self.out_proj.forward(&attn_output)?;
        Ok(attn_output)
    }
}

#[derive(Debug)]
struct MusicgenDecoderLayer {
    self_attn: MusicgenAttention,
    self_attn_layer_norm: LayerNorm,
    encoder_attn: MusicgenAttention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
    activation_fn: HiddenAct,
}

impl MusicgenDecoderLayer {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let h = cfg.hidden_size;
        let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
        let self_attn_layer_norm =
            LayerNorm::load(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
        let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
        let encoder_attn_layer_norm =
            LayerNorm::load(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
        let fc1 = Linear::load(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
        let fc2 = Linear::load(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
        let final_layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            encoder_attn,
            encoder_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
            activation_fn: cfg.activation_function,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = self.self_attn_layer_norm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, None, attention_mask)?;
        let mut xs = (xs + residual)?;
        if let Some(encoder_hidden_states) = &encoder_hidden_states {
            let residual = xs.clone();
            let encoder_attention_mask = attention_mask.clone(); // TODO
            xs = self.encoder_attn.forward(
                &xs,
                Some(encoder_hidden_states),
                &encoder_attention_mask,
            )?;
            xs = (xs + residual)?
        }
        let residual = xs.clone();
        let xs = self.final_layer_norm.forward(&xs)?;
        let xs = self.fc1.forward(&xs)?;
        let xs = self.activation_fn.forward(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = (xs + residual)?;
        Ok(xs)
    }
}

#[derive(Debug)]
struct MusicgenDecoder {
    embed_tokens: Vec<Embedding>,
    embed_positions: MusicgenSinusoidalPositionalEmbedding,
    layers: Vec<MusicgenDecoderLayer>,
    layer_norm: LayerNorm,
    embed_scale: f64,
    num_codebooks: usize,
    d_model: usize,
}

impl MusicgenDecoder {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let h = cfg.hidden_size;
        let embed_scale = if cfg.scale_embedding {
            (h as f64).sqrt()
        } else {
            1.
        };
        let embed_dim = cfg.vocab_size + 1;
        let embed_tokens = (0..cfg.num_codebooks)
            .map(|i| Embedding::load(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
            .collect::<Result<Vec<_>>>()?;
        let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
        let layers = (0..cfg.num_hidden_layers)
            .map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
            .collect::<Result<Vec<_>>>()?;
        let layer_norm = LayerNorm::load(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
        Ok(Self {
            embed_tokens,
            embed_positions,
            layers,
            layer_norm,
            embed_scale,
            num_codebooks: cfg.num_codebooks,
            d_model: cfg.hidden_size,
        })
    }

    fn prepare_decoder_attention_mask(&self, _b_sz: usize, _seq_len: usize) -> Result<Tensor> {
        todo!()
    }
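
    // A possible shape for the function above, added as a hedged sketch (the
    // commit leaves it as `todo!()`): a standard causal mask with 0 on and
    // below the diagonal and f32::NEG_INFINITY above it, broadcast to
    // (b_sz, 1, seq_len, seq_len) so it can be added to the attention
    // weights before the softmax.
    //
    // fn prepare_decoder_attention_mask(&self, b_sz: usize, seq_len: usize) -> Result<Tensor> {
    //     let mask: Vec<f32> = (0..seq_len)
    //         .flat_map(|i| (0..seq_len).map(move |j| if j <= i { 0. } else { f32::NEG_INFINITY }))
    //         .collect();
    //     let mask = Tensor::from_vec(mask, (seq_len, seq_len), &Device::Cpu)?;
    //     Ok(mask
    //         .reshape((1, 1, seq_len, seq_len))?
    //         .broadcast_as((b_sz, 1, seq_len, seq_len))?)
    // }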

    fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let dev = input_ids.device();
        let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
        let b_sz = b_sz_times_codebooks / self.num_codebooks;
        let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
        let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, &dev)?;
        for (idx, codebook) in self.embed_tokens.iter().enumerate() {
            let inp = input.narrow(1, idx, 1)?.squeeze(1)?;
            inputs_embeds = (inputs_embeds + codebook.forward(&inp)?)?
        }
        let inputs_embeds = inputs_embeds;
        let positions = self.embed_positions.forward(&input)?.to_device(&dev)?;
        let mut xs = inputs_embeds.broadcast_add(&positions)?;
        let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
        for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
            xs = decoder_layer.forward(&xs, &attention_mask, None)?;
        }
        let xs = self.layer_norm.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug)]
pub struct MusicgenForCausalLM {
    decoder: MusicgenDecoder,
    lm_heads: Vec<Linear>,
    num_codebooks: usize,
    vocab_size: usize,
}

impl MusicgenForCausalLM {
    pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let h = cfg.hidden_size;
        let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
        let lm_heads = (0..cfg.num_codebooks)
            .map(|i| Linear::load(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            decoder,
            lm_heads,
            num_codebooks: cfg.num_codebooks,
            vocab_size: cfg.vocab_size,
        })
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let (b_sz, seq_len) = input_ids.shape().r2()?;
        let hidden_states = self.decoder.forward(input_ids)?;
        let lm_logits = self
            .lm_heads
            .iter()
            .map(|h| Ok(h.forward(&hidden_states)?))
            .collect::<Result<Vec<_>>>()?;
        let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
            b_sz * self.num_codebooks,
            seq_len,
            self.vocab_size,
        ))?;
        Ok(lm_logits)
    }
}

#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
    text_encoder: crate::t5_model::T5EncoderModel,
    audio_encoder: crate::encodec_model::EncodecModel,
    decoder: MusicgenForCausalLM,
    cfg: GenConfig,
}

#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
    musicgen: Config,
    t5: crate::t5_model::Config,
    encodec: crate::encodec_model::Config,
}

impl GenConfig {
    pub fn small() -> Self {
        Self {
            musicgen: Config::musicgen_small(),
            t5: t5_model::Config::musicgen_small(),
            encodec: encodec_model::Config::musicgen_small(),
        }
    }
}

impl MusicgenForConditionalGeneration {
    pub fn config(&self) -> &GenConfig {
        &self.cfg
    }

    pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result<Self> {
        let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?;
        let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?;
        let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?;
        Ok(Self {
            text_encoder,
            audio_encoder,
            decoder,
            cfg,
        })
    }
}
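
Since every loader goes through `VarBuilder`, the whole module tree can be instantiated without a checkpoint via the zeros-backed builder from `nn.rs`. A minimal sketch, assuming CPU and F32 as in `main.rs`:

// Not part of the commit: load the full model graph against zero tensors,
// which exercises all the shape and prefix wiring without real weights.
let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
let _model = MusicgenForConditionalGeneration::load(&vb, GenConfig::small())?;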

candle-examples/examples/musicgen/nn.rs (new file, 255 lines)
@@ -0,0 +1,255 @@
#![allow(dead_code)]

use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use std::collections::HashMap;

const MAX_SEQ_LEN: usize = 5000;

pub struct VarBuilder<'a> {
    safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>,
    dtype: DType,
    device: Device,
}

impl<'a> VarBuilder<'a> {
    pub fn from_safetensors(
        safetensors: Vec<SafeTensors<'a>>,
        dtype: DType,
        device: &Device,
    ) -> Self {
        let mut routing = HashMap::new();
        for (index, sf) in safetensors.iter().enumerate() {
            for k in sf.names() {
                routing.insert(k.to_string(), index);
            }
        }
        Self {
            safetensors: Some((routing, safetensors)),
            device: device.clone(),
            dtype,
        }
    }

    pub fn zeros(dtype: DType, device: &Device) -> Self {
        Self {
            safetensors: None,
            device: device.clone(),
            dtype,
        }
    }

    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
        let s: Shape = s.into();
        match &self.safetensors {
            None => Tensor::zeros(s, self.dtype, &self.device),
            Some((routing, safetensors)) => {
                // Unwrap or 0 just to let the proper error flow.
                let index = routing.get(tensor_name).unwrap_or(&0);
                let tensor = safetensors[*index]
                    .tensor(tensor_name, &self.device)?
                    .to_dtype(self.dtype)?;
                if *tensor.shape() != s {
                    let msg = format!("shape mismatch for {tensor_name}");
                    Err(candle::Error::UnexpectedShape {
                        msg,
                        expected: s,
                        got: tensor.shape().clone(),
                    })?
                }
                Ok(tensor)
            }
        }
    }
}

#[derive(Debug)]
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl Linear {
    pub fn load(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
        let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
        let bias = if bias {
            Some(vb.get(size2, &format!("{p}.bias"))?)
        } else {
            None
        };
        Ok(Self { weight, bias })
    }

    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let (bsize, _, _) = x.shape().r3()?;
        let w = self.weight.broadcast_left(bsize)?.t()?;
        let x = x.matmul(&w)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self { weight, bias, eps }
    }

    pub fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
        let (weight, bias) = match (
            vb.get(size, &format!("{p}.weight")),
            vb.get(size, &format!("{p}.bias")),
        ) {
            (Ok(weight), Ok(bias)) => (weight, bias),
            (Err(err), _) | (_, Err(err)) => {
                if let (Ok(weight), Ok(bias)) = (
                    vb.get(size, &format!("{p}.gamma")),
                    vb.get(size, &format!("{p}.beta")),
                ) {
                    (weight, bias)
                } else {
                    return Err(err.into());
                }
            }
        };
        Ok(Self { weight, bias, eps })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let dtype = x.dtype();
        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
        let x = x.to_dtype(DType::F32)?;
        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed
            .to_dtype(dtype)?
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)?;
        Ok(x)
    }
}
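
// Added note (not part of the commit): `forward` is the standard layer norm
//   y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
// with the statistics taken over the hidden dimension (index 2); the
// normalization runs in f32 and the result is cast back to the input dtype.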

#[derive(Debug)]
pub struct Dropout {
    pr: f64,
}

impl Dropout {
    pub fn new(pr: f64) -> Self {
        Self { pr }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // TODO
        Ok(x.clone())
    }
}

#[derive(Debug)]
pub struct Embedding {
    embeddings: Tensor,
    hidden_size: usize,
}

impl Embedding {
    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
        Self {
            embeddings,
            hidden_size,
        }
    }

    pub fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
        Ok(Self::new(embeddings, hidden_size))
    }

    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        let mut final_dims = indexes.dims().to_vec();
        final_dims.push(self.hidden_size);
        let indexes = indexes.flatten_all()?;
        let values = Tensor::embedding(&indexes, &self.embeddings)?;
        let values = values.reshape(final_dims)?;
        Ok(values)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvConfig {
    pub padding: usize,
    pub stride: usize,
}

#[derive(Debug)]
pub struct Conv1D {
    weight: Tensor,
    bias: Option<Tensor>,
    config: ConvConfig,
}

impl Conv1D {
    // Applies weight norm for inference by recomputing the weight tensor. This
    // does not apply to training.
    // https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
    pub fn load_weight_norm(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        config: ConvConfig,
        p: &str,
        vb: &VarBuilder,
    ) -> Result<Self> {
        let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
        let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
        let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
        let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
        let bias = vb.get(out_c, &format!("{p}.bias"))?;
        Ok(Self {
            weight,
            bias: Some(bias),
            config,
        })
    }

    pub fn load(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        config: ConvConfig,
        p: &str,
        vb: &VarBuilder,
    ) -> Result<Self> {
        let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
        let bias = vb.get(out_c, &format!("{p}.bias"))?;
        Ok(Self {
            weight,
            bias: Some(bias),
            config,
        })
    }
}
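
// Added note (not part of the commit): `load_weight_norm` folds PyTorch's
// weight normalization into a plain convolution weight at load time:
//   weight = weight_g * weight_v / ||weight_v||
// with the L2 norm taken per output channel over the (in_channel, kernel)
// dimensions, i.e. torch.nn.utils.weight_norm with dim = 0.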

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HiddenAct {
    Gelu,
    Relu,
}

impl HiddenAct {
    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu(),
            Self::Relu => xs.relu(),
        }
    }
}
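
A quick shape check for `Embedding::forward`, added as a hedged illustration: the output appends `hidden_size` to whatever index shape goes in.

// Assumed illustration, not part of the commit: a (2, 3) index tensor looked
// up in a (10, 4) table yields a (2, 3, 4) output.
let table = Tensor::zeros((10, 4), DType::F32, &Device::Cpu)?;
let emb = Embedding::new(table, 4);
let idx = Tensor::zeros((2, 3), DType::U32, &Device::Cpu)?;
assert_eq!(emb.forward(&idx)?.dims(), &[2, 3, 4]);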

candle-examples/examples/musicgen/t5_model.rs (new file, 288 lines)
@@ -0,0 +1,288 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
use anyhow::Result;
use candle::Tensor;

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    vocab_size: usize,
    d_model: usize,
    d_kv: usize,
    d_ff: usize,
    num_layers: usize,
    num_decoder_layers: Option<usize>,
    num_heads: usize,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    dropout_rate: f64,
    layer_norm_epsilon: f64,
    initializer_factor: f64,
    feed_forward_proj: HiddenAct,
    is_decoder: bool,
    is_encoder_decoder: bool,
    use_cache: bool,
    pad_token_id: usize,
    eos_token_id: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: HiddenAct::Relu,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
        }
    }
}

impl Config {
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
    pub fn musicgen_small() -> Self {
        Self {
            d_ff: 3072,
            d_kv: 64,
            d_model: 768,
            dropout_rate: 0.1,
            eos_token_id: 1,
            feed_forward_proj: HiddenAct::Relu,
            initializer_factor: 1.0,
            is_decoder: false,
            is_encoder_decoder: true,
            layer_norm_epsilon: 1e-6,
            num_decoder_layers: Some(12),
            num_heads: 12,
            num_layers: 12,
            pad_token_id: 0,
            relative_attention_max_distance: 128,
            relative_attention_num_buckets: 32,
            use_cache: true,
            vocab_size: 32128,
        }
    }
}

#[derive(Debug)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
        let weight = vb.get(h, &format!("{p}.weight"))?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
        })
    }
}

#[derive(Debug)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    dropout: Dropout,
    act: HiddenAct,
}

impl T5DenseActDense {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
        let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
        let dropout = Dropout::new(cfg.dropout_rate);
        Ok(Self {
            wi,
            wo,
            dropout,
            act: HiddenAct::Relu,
        })
    }
}

#[derive(Debug)]
struct T5LayerFF {
    dense_relu_dense: T5DenseActDense,
    layer_norm: T5LayerNorm,
    dropout: Dropout,
}

impl T5LayerFF {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        // is_gated_act is not supported.
        let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?;
        let layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            &format!("{p}.layer_norm"),
            vb,
        )?;
        let dropout = Dropout::new(cfg.dropout_rate);
        Ok(Self {
            dense_relu_dense,
            layer_norm,
            dropout,
        })
    }
}

#[derive(Debug)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    relative_attention_bias: Option<Embedding>,
}

impl T5Attention {
    fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
        let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
        let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
        let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
        let relative_attention_bias = if h {
            let emb = Embedding::load(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                &format!("{p}.relative_attention_bias"),
                vb,
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            relative_attention_bias,
        })
    }
}

#[derive(Debug)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
    dropout: Dropout,
}

impl T5LayerSelfAttention {
    fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?;
        let layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            &format!("{p}.layer_norm"),
            vb,
        )?;
        let dropout = Dropout::new(cfg.dropout_rate);
        Ok(Self {
            self_attention,
            layer_norm,
            dropout,
        })
    }
}

#[derive(Debug)]
struct T5LayerCrossAttention {}

impl T5LayerCrossAttention {
    fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result<Self> {
        todo!()
    }
}

#[derive(Debug)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
}

impl T5Block {
    fn load(
        has_relative_attention_bias: bool,
        p: &str,
        vb: &VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let p = &format!("{p}.layer");
        let self_attn =
            T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
        })
    }
}

#[derive(Debug)]
struct T5Stack {
    // TODO: Add embed_tokens if needed (shared embedding layer).
    block: Vec<T5Block>,
    final_layer_norm: T5LayerNorm,
    dropout: Dropout,
}

impl T5Stack {
    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            &format!("{p}.final_layer_norm"),
            vb,
        )?;
        let dropout = Dropout::new(cfg.dropout_rate);
        Ok(Self {
            block,
            final_layer_norm,
            dropout,
        })
    }
}
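
// Added note (not part of the commit): only block 0 is constructed with
// has_relative_attention_bias = true (the `i == 0` above). This mirrors the
// HF T5 implementation, where the relative position bias is computed once by
// the first block's self-attention and reused by all later blocks.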

#[derive(Debug)]
pub struct T5EncoderModel {
    shared: Embedding,
    encoder: T5Stack,
}

impl T5EncoderModel {
    pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
        let shared = Embedding::load(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
        let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
        Ok(Self { shared, encoder })
    }
}
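
`T5LayerNorm` only stores a scale and an epsilon because T5 uses RMS normalization: no mean subtraction and no bias. A hedged sketch of the forward pass the commit has not yet added, following the HF reference:

// Assumed implementation of the missing T5LayerNorm::forward, not part of
// the commit: y = x / sqrt(mean(x^2) + eps) * weight, with the mean taken
// over the hidden dimension.
impl T5LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (_b_sz, _seq_len, hidden_size) = x.shape().r3()?;
        let variance = ((x * x)?.sum(&[2])? / hidden_size as f64)?;
        let x = x.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        Ok(x.broadcast_mul(&self.weight)?)
    }
}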