MusicGen var-store path cleanup. (#132)

This commit is contained in:
Laurent Mazare
2023-07-10 23:13:11 +01:00
committed by GitHub
parent b46c28a2ac
commit 6fc1ab4f0d
5 changed files with 128 additions and 171 deletions

View File

@@ -111,7 +111,7 @@ struct MusicgenSinusoidalPositionalEmbedding {
}
impl MusicgenSinusoidalPositionalEmbedding {
fn load(_vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(_vb: VarBuilder, cfg: &Config) -> Result<Self> {
let num_positions = cfg.max_position_embeddings;
let embedding_dim = cfg.hidden_size;
let weights = get_embedding(num_positions, embedding_dim)?;
@@ -144,14 +144,14 @@ struct MusicgenAttention {
}
impl MusicgenAttention {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = h / num_heads;
let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
Ok(Self {
scaling: 1. / (head_dim as f64).sqrt(),
is_decoder: true,
@@ -212,16 +212,15 @@ struct MusicgenDecoderLayer {
}
impl MusicgenDecoderLayer {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
let encoder_attn_layer_norm =
layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
let self_attn = MusicgenAttention::load(vb.pp("self_attn"), cfg)?;
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
@@ -276,7 +275,7 @@ struct MusicgenDecoder {
}
impl MusicgenDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let embed_scale = if cfg.scale_embedding {
(h as f64).sqrt()
@@ -285,13 +284,13 @@ impl MusicgenDecoder {
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
embed_tokens,
embed_positions,
@@ -338,11 +337,11 @@ pub struct MusicgenForCausalLM {
}
impl MusicgenForCausalLM {
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,
@@ -399,10 +398,11 @@ impl MusicgenForConditionalGeneration {
&self.cfg
}
pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?;
let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?;
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self {
text_encoder,
audio_encoder,