MusicGen var-store path cleanup. (#132)

commit 6fc1ab4f0d (parent b46c28a2ac)
Author: Laurent Mazare
Committed by: GitHub
Date: 2023-07-10 23:13:11 +01:00

5 changed files with 128 additions and 171 deletions
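The diff below removes the `p: &str` path prefix that every `load` function threaded alongside a `&VarBuilder`, and instead passes the `VarBuilder` by value, scoped with `pp` at each call site. Loaders then look weights up by local name ("weight", "wi") rather than formatting full dotted paths by hand. A minimal, self-contained sketch of the scoping idea follows; this toy `VarBuilder` is illustrative only, the real one in the example also owns the tensor store that `get` reads from.

// Toy sketch of prefix scoping: `pp` ("push prefix") returns a builder
// whose lookups are all relative to the accumulated path.
#[derive(Clone, Default)]
struct VarBuilder {
    path: Vec<String>,
}

impl VarBuilder {
    fn pp(&self, s: &str) -> VarBuilder {
        let mut path = self.path.clone();
        path.push(s.to_string());
        VarBuilder { path }
    }

    // Stand-in for the real `get`: resolves a local name to the full
    // dotted key the tensor would be fetched under.
    fn full_name(&self, name: &str) -> String {
        let mut parts = self.path.clone();
        parts.push(name.to_string());
        parts.join(".")
    }
}

fn main() {
    let vb = VarBuilder::default();
    let q = vb
        .pp("encoder")
        .pp("block.0")
        .pp("layer")
        .pp("0")
        .pp("SelfAttention")
        .pp("q");
    // Prints "encoder.block.0.layer.0.SelfAttention.q.weight": the same flat
    // key the old `&format!("{p}.q")` plumbing built by hand.
    println!("{}", q.full_name("weight"));
}

Passing the builder by value keeps each loader ignorant of where it sits in the checkpoint; only the call sites decide the layout, which is what makes the per-layer paths in the hunks below so short.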


@@ -85,8 +85,8 @@ struct T5LayerNorm {
 }
 
 impl T5LayerNorm {
-    fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get(h, &format!("{p}.weight"))?;
+    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let weight = vb.get(h, "weight")?;
         Ok(Self {
             weight,
             variance_epsilon: eps,
@@ -103,9 +103,9 @@ struct T5DenseActDense {
 }
 
 impl T5DenseActDense {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
-        let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
+        let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             wi,
@@ -124,15 +124,11 @@ struct T5LayerFF {
 }
 
 impl T5LayerFF {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         // is_gated_act is not supported.
-        let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?;
-        let layer_norm = T5LayerNorm::load(
-            cfg.d_model,
-            cfg.layer_norm_epsilon,
-            &format!("{p}.layer_norm"),
-            vb,
-        )?;
+        let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             dense_relu_dense,
@@ -152,18 +148,17 @@ struct T5Attention {
 }
 
 impl T5Attention {
-    fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
-        let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
-        let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
-        let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
+        let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
+        let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
+        let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
+        let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
         let relative_attention_bias = if h {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
-                &format!("{p}.relative_attention_bias"),
-                vb,
+                vb.pp("relative_attention_bias"),
             )?;
             Some(emb)
         } else {
@@ -187,14 +182,10 @@ struct T5LayerSelfAttention {
 }
 
 impl T5LayerSelfAttention {
-    fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?;
-        let layer_norm = T5LayerNorm::load(
-            cfg.d_model,
-            cfg.layer_norm_epsilon,
-            &format!("{p}.layer_norm"),
-            vb,
-        )?;
+    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             self_attention,
@@ -208,7 +199,7 @@ impl T5LayerSelfAttention {
 struct T5LayerCrossAttention {}
 
 impl T5LayerCrossAttention {
-    fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result<Self> {
+    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
         todo!()
     }
 }
@@ -221,22 +212,16 @@ struct T5Block {
 }
 
 impl T5Block {
-    fn load(
-        has_relative_attention_bias: bool,
-        p: &str,
-        vb: &VarBuilder,
-        cfg: &Config,
-    ) -> Result<Self> {
-        let p = &format!("{p}.layer");
-        let self_attn =
-            T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?;
+    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let vb = vb.pp("layer");
+        let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
         let cross_attn = if cfg.is_decoder {
-            Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?)
+            Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
         } else {
             None
         };
         let ff_i = if cross_attn.is_some() { 2 } else { 1 };
-        let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?;
+        let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
         Ok(Self {
             self_attn,
             cross_attn,
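One detail in the hunk above: path components are arbitrary strings, not identifiers, so the numeric sub-layer indices of the T5 layout (layer.0, layer.1, layer.2) go through `pp` as strings, with integers converted first (`vb.pp("0")`, `vb.pp(&ff_i.to_string())`). Reusing the toy builder sketched earlier (the `decoder` prefix here is assumed for illustration):

// A decoder block has cross-attention at layer.1, so its feed-forward
// sub-layer lands at layer.2; encoder blocks put it at layer.1.
let is_decoder = true; // stands in for cfg.is_decoder
let ff_i: usize = if is_decoder { 2 } else { 1 };
let vb = VarBuilder::default().pp("decoder").pp("block.0").pp("layer");
// Resolves to "decoder.block.0.layer.2.DenseReluDense.wi.weight".
let key = vb
    .pp(&ff_i.to_string())
    .pp("DenseReluDense")
    .pp("wi")
    .full_name("weight");
println!("{key}");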
@@ -254,15 +239,14 @@ struct T5Stack {
 }
 
 impl T5Stack {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
-            .map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg))
+            .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
         let final_layer_norm = T5LayerNorm::load(
             cfg.d_model,
             cfg.layer_norm_epsilon,
-            &format!("{p}.final_layer_norm"),
-            vb,
+            vb.pp("final_layer_norm"),
         )?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
@@ -280,9 +264,9 @@ pub struct T5EncoderModel {
 }
 
 impl T5EncoderModel {
-    pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
-        let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
         Ok(Self { shared, encoder })
     }
 }
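At the call sites, which this section of the diff does not show, the prefix argument disappears in favor of scoping the builder once. A hypothetical before/after; the `t5` prefix is an assumption for illustration, not taken from the diff:

// Old: path threaded as a string next to a borrowed builder.
// let encoder = T5EncoderModel::load("t5", &vb, &cfg)?;

// New: the same scope pushed onto the builder itself (hypothetical prefix).
let encoder = T5EncoderModel::load(vb.pp("t5"), &cfg)?;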