Fix the musicgen example. (#724)

* Fix the musicgen example.

* Retrieve the weights from the hub.
Author: Laurent Mazare
Date: 2023-09-03 15:50:39 +02:00
Committed by: GitHub
Parent: f7980e07e0
Commit: bbec527bb9
5 changed files with 62 additions and 134 deletions
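
The second bullet of the commit message, "Retrieve the weights from the hub", means the example now fetches the pretrained MusicGen weights from the Hugging Face Hub instead of expecting a local file. As a rough sketch of what that typically looks like with the hf-hub crate (the repo id and filename below are illustrative assumptions, not taken from this commit):

use anyhow::Result;
use hf_hub::api::sync::Api;

fn main() -> Result<()> {
    // Hypothetical repo id and weight filename, for illustration only.
    let api = Api::new()?;
    let repo = api.model("facebook/musicgen-small".to_string());
    // Downloads the file on first use and returns the local cache path afterwards.
    let weights = repo.get("model.safetensors")?;
    println!("weights cached at {weights:?}");
    Ok(())
}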

@@ -1,10 +1,9 @@
-use crate::nn::{
-    embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
-};
 use crate::{encodec_model, t5_model};
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
-use candle_nn::Module;
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{
+    embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
+    VarBuilder,
+};
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -16,7 +15,7 @@ pub struct Config {
     num_attention_heads: usize,
     layerdrop: f64,
     use_cache: bool,
-    activation_function: HiddenAct,
+    activation_function: Activation,
     hidden_size: usize,
     dropout: f64,
     attention_dropout: f64,
@@ -40,7 +39,7 @@ impl Default for Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -66,7 +65,7 @@ impl Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -128,7 +127,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
         if seq_len > self.weights.dim(0)? {
             self.weights = get_embedding(seq_len, self.embedding_dim)?
         }
-        Ok(self.weights.narrow(0, 0, seq_len)?)
+        self.weights.narrow(0, 0, seq_len)
     }
 }
 
@@ -149,10 +148,10 @@ impl MusicgenAttention {
         let h = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let head_dim = h / num_heads;
-        let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
-        let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
-        let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
-        let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
+        let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
+        let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
+        let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
         Ok(Self {
             scaling: 1. / (head_dim as f64).sqrt(),
             is_decoder: true,
@@ -209,7 +208,7 @@ struct MusicgenDecoderLayer {
     fc1: Linear,
     fc2: Linear,
     final_layer_norm: LayerNorm,
-    activation_fn: HiddenAct,
+    activation_fn: Activation,
 }
 
 impl MusicgenDecoderLayer {
@@ -219,8 +218,8 @@ impl MusicgenDecoderLayer {
         let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
         let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
-        let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
-        let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
+        let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
+        let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
         Ok(Self {
             self_attn,
@@ -342,7 +341,7 @@ impl MusicgenForCausalLM {
         let h = cfg.hidden_size;
         let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
         let lm_heads = (0..cfg.num_codebooks)
-            .map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
+            .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             decoder,
@@ -358,7 +357,7 @@ impl MusicgenForCausalLM {
         let lm_logits = self
             .lm_heads
             .iter()
-            .map(|h| Ok(h.forward(&hidden_states)?))
+            .map(|h| h.forward(&hidden_states))
             .collect::<Result<Vec<_>>>()?;
         let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
             b_sz * self.num_codebooks,
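
Taken together, the changes replace the example's hand-rolled nn helpers with their candle_nn equivalents (linear_no_bias instead of linear(_, _, false, _), Activation instead of HiddenAct) and return candle::Result directly rather than wrapping values in anyhow::Result. A minimal sketch of the new-style projection call, using a freshly initialized VarMap in place of the checkpoint-backed VarBuilder the real example uses (an assumption made only to keep the snippet self-contained):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear_no_bias, Linear, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Freshly initialized variables stand in for real checkpoint weights here.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let h = 1024;
    // Same shape as the attention projections in the diff: h -> h, no bias term.
    let q_proj: Linear = linear_no_bias(h, h, vb.pp("q_proj"))?;
    let xs = Tensor::zeros((1, 4, h), DType::F32, &dev)?;
    let ys = q_proj.forward(&xs)?;
    println!("{:?}", ys.shape());
    Ok(())
}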