mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
MusicGen var-store path cleanup. (#132)
This commit is contained in:
@ -111,7 +111,7 @@ struct MusicgenSinusoidalPositionalEmbedding {
|
||||
}
|
||||
|
||||
impl MusicgenSinusoidalPositionalEmbedding {
|
||||
fn load(_vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(_vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let num_positions = cfg.max_position_embeddings;
|
||||
let embedding_dim = cfg.hidden_size;
|
||||
let weights = get_embedding(num_positions, embedding_dim)?;
|
||||
@ -144,14 +144,14 @@ struct MusicgenAttention {
|
||||
}
|
||||
|
||||
impl MusicgenAttention {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let head_dim = h / num_heads;
|
||||
let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
|
||||
let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
|
||||
let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
|
||||
let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
|
||||
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
|
||||
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
|
||||
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
scaling: 1. / (head_dim as f64).sqrt(),
|
||||
is_decoder: true,
|
||||
@ -212,16 +212,15 @@ struct MusicgenDecoderLayer {
|
||||
}
|
||||
|
||||
impl MusicgenDecoderLayer {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
|
||||
let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
|
||||
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
|
||||
let encoder_attn_layer_norm =
|
||||
layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
||||
let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
|
||||
let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
|
||||
let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
|
||||
let self_attn = MusicgenAttention::load(vb.pp("self_attn"), cfg)?;
|
||||
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
|
||||
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
|
||||
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
self_attn_layer_norm,
|
||||
@ -276,7 +275,7 @@ struct MusicgenDecoder {
|
||||
}
|
||||
|
||||
impl MusicgenDecoder {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let embed_scale = if cfg.scale_embedding {
|
||||
(h as f64).sqrt()
|
||||
@ -285,13 +284,13 @@ impl MusicgenDecoder {
|
||||
};
|
||||
let embed_dim = cfg.vocab_size + 1;
|
||||
let embed_tokens = (0..cfg.num_codebooks)
|
||||
.map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
|
||||
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
|
||||
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
|
||||
let layers = (0..cfg.num_hidden_layers)
|
||||
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
|
||||
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
|
||||
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
embed_positions,
|
||||
@ -338,11 +337,11 @@ pub struct MusicgenForCausalLM {
|
||||
}
|
||||
|
||||
impl MusicgenForCausalLM {
|
||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
|
||||
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
||||
let lm_heads = (0..cfg.num_codebooks)
|
||||
.map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
|
||||
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
decoder,
|
||||
@ -399,10 +398,11 @@ impl MusicgenForConditionalGeneration {
|
||||
&self.cfg
|
||||
}
|
||||
|
||||
pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?;
|
||||
let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?;
|
||||
let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?;
|
||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let audio_encoder =
|
||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||
Ok(Self {
|
||||
text_encoder,
|
||||
audio_encoder,
|
||||
|
Reference in New Issue
Block a user