mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
MusicGen var-store path cleanup. (#132)
This commit is contained in:
@ -85,8 +85,8 @@ struct T5LayerNorm {
|
||||
}
|
||||
|
||||
impl T5LayerNorm {
|
||||
fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(h, &format!("{p}.weight"))?;
|
||||
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(h, "weight")?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
variance_epsilon: eps,
|
||||
@ -103,9 +103,9 @@ struct T5DenseActDense {
|
||||
}
|
||||
|
||||
impl T5DenseActDense {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
|
||||
let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
|
||||
let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
wi,
|
||||
@ -124,15 +124,11 @@ struct T5LayerFF {
|
||||
}
|
||||
|
||||
impl T5LayerFF {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
// is_gated_act is not supported.
|
||||
let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?;
|
||||
let layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.layer_norm"),
|
||||
vb,
|
||||
)?;
|
||||
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
dense_relu_dense,
|
||||
@ -152,18 +148,17 @@ struct T5Attention {
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||
let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
|
||||
let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
|
||||
let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
|
||||
let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
|
||||
let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
|
||||
let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
|
||||
let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
|
||||
let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
|
||||
let relative_attention_bias = if h {
|
||||
let emb = embedding(
|
||||
cfg.relative_attention_num_buckets,
|
||||
cfg.num_heads,
|
||||
&format!("{p}.relative_attention_bias"),
|
||||
vb,
|
||||
vb.pp("relative_attention_bias"),
|
||||
)?;
|
||||
Some(emb)
|
||||
} else {
|
||||
@ -187,14 +182,10 @@ struct T5LayerSelfAttention {
|
||||
}
|
||||
|
||||
impl T5LayerSelfAttention {
|
||||
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?;
|
||||
let layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.layer_norm"),
|
||||
vb,
|
||||
)?;
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
@ -208,7 +199,7 @@ impl T5LayerSelfAttention {
|
||||
struct T5LayerCrossAttention {}
|
||||
|
||||
impl T5LayerCrossAttention {
|
||||
fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result<Self> {
|
||||
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
@ -221,22 +212,16 @@ struct T5Block {
|
||||
}
|
||||
|
||||
impl T5Block {
|
||||
fn load(
|
||||
has_relative_attention_bias: bool,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let p = &format!("{p}.layer");
|
||||
let self_attn =
|
||||
T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?;
|
||||
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = vb.pp("layer");
|
||||
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
|
||||
let cross_attn = if cfg.is_decoder {
|
||||
Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?)
|
||||
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
|
||||
let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?;
|
||||
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
cross_attn,
|
||||
@ -254,15 +239,14 @@ struct T5Stack {
|
||||
}
|
||||
|
||||
impl T5Stack {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let block = (0..cfg.num_layers)
|
||||
.map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg))
|
||||
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let final_layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
cfg.layer_norm_epsilon,
|
||||
&format!("{p}.final_layer_norm"),
|
||||
vb,
|
||||
vb.pp("final_layer_norm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(cfg.dropout_rate);
|
||||
Ok(Self {
|
||||
@ -280,9 +264,9 @@ pub struct T5EncoderModel {
|
||||
}
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
|
||||
let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
|
||||
Ok(Self { shared, encoder })
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user