Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)
Fix the musicgen example. (#724)

* Fix the musicgen example.
* Retrieve the weights from the hub.
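The second bullet (retrieving the weights from the hub) is not visible in the hunks below, which only touch the model file. For context, a minimal sketch of how a candle example typically fetches weights with the `hf-hub` crate; the repo id and file name here are illustrative assumptions, not taken from this commit:

    use anyhow::Result;
    use hf_hub::api::sync::Api;

    fn main() -> Result<()> {
        // Resolve (and cache) the weight file from the Hugging Face Hub
        // instead of requiring a local path on the command line.
        let api = Api::new()?;
        let repo = api.model("facebook/musicgen-small".to_string()); // assumed repo id
        let weights = repo.get("model.safetensors")?; // assumed file name
        println!("weights cached at {}", weights.display());
        Ok(())
    }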
@@ -1,10 +1,9 @@
-use crate::nn::{
-    embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
-};
 use crate::{encodec_model, t5_model};
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
-use candle_nn::Module;
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{
+    embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
+    VarBuilder,
+};
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -16,7 +15,7 @@ pub struct Config {
     num_attention_heads: usize,
     layerdrop: f64,
     use_cache: bool,
-    activation_function: HiddenAct,
+    activation_function: Activation,
     hidden_size: usize,
     dropout: f64,
     attention_dropout: f64,
@@ -40,7 +39,7 @@ impl Default for Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -66,7 +65,7 @@ impl Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
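Both Config constructors above swap the crate-local `HiddenAct` enum for `candle_nn::Activation`. Since `Activation` implements `Module`, it can be stored in the config and applied to a tensor directly; a minimal sketch (the tensor shape is arbitrary):

    use candle::{Device, Result, Tensor};
    use candle_nn::{Activation, Module};

    fn main() -> Result<()> {
        // Activation is a plain enum, so it can be kept in a Config struct
        // and still be applied like a layer through the Module trait.
        let act = Activation::Gelu;
        let xs = Tensor::randn(0f32, 1f32, (2, 4), &Device::Cpu)?;
        let ys = act.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }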
@@ -128,7 +127,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
         if seq_len > self.weights.dim(0)? {
             self.weights = get_embedding(seq_len, self.embedding_dim)?
         }
-        Ok(self.weights.narrow(0, 0, seq_len)?)
+        self.weights.narrow(0, 0, seq_len)
     }
 }
 
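This simplification falls out of the import change: the method now returns `candle::Result<Tensor>`, so re-wrapping the value as `Ok(...?)` is redundant. A standalone illustration with hypothetical function names:

    use candle::Tensor;

    // Before: with anyhow::Result, the candle error had to be converted by
    // `?` and the value re-wrapped in Ok.
    fn narrow_old(weights: &Tensor, seq_len: usize) -> anyhow::Result<Tensor> {
        Ok(weights.narrow(0, 0, seq_len)?)
    }

    // After: with candle::Result, the inner call already has the right
    // type and can be returned as-is.
    fn narrow_new(weights: &Tensor, seq_len: usize) -> candle::Result<Tensor> {
        weights.narrow(0, 0, seq_len)
    }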
@@ -149,10 +148,10 @@ impl MusicgenAttention {
         let h = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let head_dim = h / num_heads;
-        let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
-        let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
-        let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
-        let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
+        let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
+        let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
+        let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
         Ok(Self {
             scaling: 1. / (head_dim as f64).sqrt(),
             is_decoder: true,
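The bias flag disappears because `candle_nn` splits the constructor in two: `linear(in_dim, out_dim, vb)` creates a weight and a bias, while `linear_no_bias(in_dim, out_dim, vb)` creates only the weight, matching MusicGen's bias-free projections (the same change applies to `fc1`/`fc2` and the `lm_heads` below). A minimal sketch using a fresh `VarMap`-backed `VarBuilder`; the dimensions and variable names are placeholders:

    use candle::{DType, Device, Result};
    use candle_nn::{linear, linear_no_bias, Linear, VarBuilder, VarMap};

    fn build(vb: VarBuilder) -> Result<(Linear, Linear)> {
        let with_bias = linear(1024, 1024, vb.pp("with_bias"))?;
        let no_bias = linear_no_bias(1024, 1024, vb.pp("no_bias"))?;
        Ok((with_bias, no_bias))
    }

    fn main() -> Result<()> {
        // A VarMap-backed VarBuilder initializes fresh tensors, which is
        // enough to exercise both constructors outside of weight loading.
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let (_with_bias, _no_bias) = build(vb)?;
        Ok(())
    }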
@@ -209,7 +208,7 @@ struct MusicgenDecoderLayer {
     fc1: Linear,
     fc2: Linear,
     final_layer_norm: LayerNorm,
-    activation_fn: HiddenAct,
+    activation_fn: Activation,
 }
 
 impl MusicgenDecoderLayer {
@@ -219,8 +218,8 @@ impl MusicgenDecoderLayer {
         let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
         let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
-        let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
-        let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
+        let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
+        let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
         Ok(Self {
             self_attn,
@@ -342,7 +341,7 @@ impl MusicgenForCausalLM {
         let h = cfg.hidden_size;
         let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
         let lm_heads = (0..cfg.num_codebooks)
-            .map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
+            .map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             decoder,
@@ -358,7 +357,7 @@ impl MusicgenForCausalLM {
         let lm_logits = self
             .lm_heads
             .iter()
-            .map(|h| Ok(h.forward(&hidden_states)?))
+            .map(|h| h.forward(&hidden_states))
             .collect::<Result<Vec<_>>>()?;
         let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
             b_sz * self.num_codebooks,
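Here the `Ok(...?)` wrapper in the closure goes away for the same reason as above: `Module::forward` already returns `candle::Result<Tensor>`, and collecting into `Result<Vec<_>>` short-circuits on the first error. A small standalone sketch of the collect-then-stack pattern; `run_heads` is a hypothetical helper:

    use candle::{Result, Tensor};
    use candle_nn::{Linear, Module};

    fn run_heads(lm_heads: &[Linear], hidden_states: &Tensor) -> Result<Tensor> {
        // Each head maps hidden states to its codebook's logits; collecting
        // into Result<Vec<_>> stops at the first failing head.
        let lm_logits = lm_heads
            .iter()
            .map(|h| h.forward(hidden_states))
            .collect::<Result<Vec<_>>>()?;
        // Stack per-head logits along a new codebook dimension, as the
        // MusicGen forward pass does before reshaping.
        Tensor::stack(&lm_logits, 1)
    }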