From 6fc1ab4f0dc4de62ee106d8c8f2b38156a9b6760 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 10 Jul 2023 23:13:11 +0100 Subject: [PATCH] MusicGen var-store path cleanup. (#132) --- .../examples/musicgen/encodec_model.rs | 122 ++++++++---------- candle-examples/examples/musicgen/main.rs | 2 +- .../examples/musicgen/musicgen_model.rs | 54 ++++---- candle-examples/examples/musicgen/nn.rs | 43 +++--- candle-examples/examples/musicgen/t5_model.rs | 78 +++++------ 5 files changed, 128 insertions(+), 171 deletions(-) diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 1c071c1a..242c349c 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -127,12 +127,12 @@ struct EncodecEuclideanCodebook { } impl EncodecEuclideanCodebook { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let inited = vb.get(1, &format!("{p}.inited"))?; - let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?; + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let inited = vb.get(1, "inited")?; + let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?; let e_shape = (cfg.codebook_size, cfg.codebook_dim()); - let embed = vb.get(e_shape, &format!("{p}.embed"))?; - let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?; + let embed = vb.get(e_shape, "embed")?; + let embed_avg = vb.get(e_shape, "embed_avg")?; Ok(Self { inited, cluster_size, @@ -148,8 +148,8 @@ struct EncodecVectorQuantization { } impl EncodecVectorQuantization { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?; + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?; Ok(Self { codebook }) } } @@ -160,10 +160,10 @@ struct EncodecResidualVectorQuantizer { } impl EncodecResidualVectorQuantizer { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let p = format!("{p}.layers"); + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let vb = &vb.pp("layers"); let layers = (0..cfg.num_quantizers()) - .map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg)) + .map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg)) .collect::>>()?; Ok(Self { layers }) } @@ -176,14 +176,14 @@ struct EncodecLSTM { } impl EncodecLSTM { - fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let p = format!("{p}.lstm"); + fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result { + let vb = &vb.pp("lstm"); let mut layers = vec![]; for i in 0..cfg.num_lstm_layers { - let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?; - let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?; - let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?; - let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?; + let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?; + let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?; + let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?; + let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?; layers.push((w_hh, w_ih, b_hh, b_ih)) } Ok(Self { layers }) @@ -203,14 +203,13 @@ impl EncodecConvTranspose1d { out_c: usize, k: usize, _stride: usize, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, _cfg: &Config, ) -> Result { - let p = format!("{p}.conv"); - let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?; - let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?; - let bias = vb.get(out_c, &format!("{p}.bias"))?; + let vb = &vb.pp("conv"); + let weight_g = vb.get((in_c, 1, 1), "weight_g")?; + let weight_v = vb.get((in_c, out_c, k), "weight_v")?; + let bias = vb.get(out_c, "bias")?; Ok(Self { weight_g, weight_v, @@ -230,8 +229,7 @@ impl EncodecConv1d { out_c: usize, kernel_size: usize, stride: usize, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, cfg: &Config, ) -> Result { let conv = match cfg.norm_type { @@ -240,16 +238,14 @@ impl EncodecConv1d { out_c, kernel_size, Conv1dConfig { padding: 0, stride }, - &format!("{p}.conv"), - vb, + vb.pp("conv"), )?, NormType::None => conv1d( in_c, out_c, kernel_size, Conv1dConfig { padding: 0, stride }, - &format!("{p}.conv"), - vb, + vb.pp("conv"), )?, }; Ok(Self { conv }) @@ -264,15 +260,9 @@ struct EncodecResnetBlock { } impl EncodecResnetBlock { - fn load( - dim: usize, - dilations: &[usize], - p: &str, - vb: &VarBuilder, - cfg: &Config, - ) -> Result { + fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result { let h = dim / cfg.compress; - let mut layer = Layer::new(format!("{p}.block")); + let mut layer = Layer::new("block"); if dilations.len() != 2 { anyhow::bail!("expected dilations of size 2") } @@ -283,14 +273,13 @@ impl EncodecResnetBlock { h, cfg.residual_kernel_size, 1, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; layer.inc(); - let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?; + let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?; let shortcut = if cfg.use_conv_shortcut { - let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?; + let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?; Some(conv) } else { None @@ -310,8 +299,11 @@ struct Layer { } impl Layer { - fn new(prefix: String) -> Self { - Self { prefix, cnt: 0 } + fn new(prefix: &str) -> Self { + Self { + prefix: prefix.to_string(), + cnt: 0, + } } fn inc(&mut self) { @@ -334,15 +326,14 @@ struct EncodecEncoder { } impl EncodecEncoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let mut layer = Layer::new(format!("{p}.layers")); + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let mut layer = Layer::new("layers"); let init_conv = EncodecConv1d::load( cfg.audio_channels, cfg.num_filters, cfg.kernel_size, 1, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; let mut sampling_layers = vec![]; @@ -354,8 +345,7 @@ impl EncodecEncoder { let resnet = EncodecResnetBlock::load( current_scale, &[cfg.dilation_growth_rate.pow(j), 1], - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; resnets.push(resnet) @@ -366,22 +356,21 @@ impl EncodecEncoder { current_scale * 2, ratio * 2, ratio, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; sampling_layers.push((resnets, conv1d)); scaling *= 2; } - let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?; + let final_lstm = + EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?; layer.inc(); // ELU let final_conv = EncodecConv1d::load( cfg.num_filters * scaling, cfg.hidden_size, cfg.last_kernel_size, 1, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; Ok(Self { @@ -402,19 +391,19 @@ struct EncodecDecoder { } impl EncodecDecoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let mut layer = Layer::new(format!("{p}.layers")); + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let mut layer = Layer::new("layers"); let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32); let init_conv = EncodecConv1d::load( cfg.hidden_size, cfg.num_filters * scaling, cfg.last_kernel_size, 1, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; - let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?; + let init_lstm = + EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?; let mut sampling_layers = vec![]; for &ratio in cfg.upsampling_ratios.iter() { let current_scale = scaling * cfg.num_filters; @@ -424,8 +413,7 @@ impl EncodecDecoder { current_scale / 2, ratio * 2, ratio, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; let mut resnets = vec![]; @@ -433,8 +421,7 @@ impl EncodecDecoder { let resnet = EncodecResnetBlock::load( current_scale / 2, &[cfg.dilation_growth_rate.pow(j), 1], - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; resnets.push(resnet) @@ -448,8 +435,7 @@ impl EncodecDecoder { cfg.audio_channels, cfg.last_kernel_size, 1, - &layer.next_name(), - vb, + vb.pp(&layer.next_name()), cfg, )?; Ok(Self { @@ -469,10 +455,10 @@ pub struct EncodecModel { } impl EncodecModel { - pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?; - let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?; - let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?; + pub fn load(vb: VarBuilder, cfg: &Config) -> Result { + let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?; + let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?; + let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?; Ok(Self { encoder, decoder, diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index a2f44b04..cb94d3df 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -54,6 +54,6 @@ fn main() -> Result<()> { let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); - let _model = MusicgenForConditionalGeneration::load(&vb, config)?; + let _model = MusicgenForConditionalGeneration::load(vb, config)?; Ok(()) } diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index 6ef15c21..512e35e8 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -111,7 +111,7 @@ struct MusicgenSinusoidalPositionalEmbedding { } impl MusicgenSinusoidalPositionalEmbedding { - fn load(_vb: &VarBuilder, cfg: &Config) -> Result { + fn load(_vb: VarBuilder, cfg: &Config) -> Result { let num_positions = cfg.max_position_embeddings; let embedding_dim = cfg.hidden_size; let weights = get_embedding(num_positions, embedding_dim)?; @@ -144,14 +144,14 @@ struct MusicgenAttention { } impl MusicgenAttention { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { let h = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let head_dim = h / num_heads; - let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?; - let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?; - let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?; - let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?; + let k_proj = linear(h, h, false, vb.pp("k_proj"))?; + let v_proj = linear(h, h, false, vb.pp("v_proj"))?; + let q_proj = linear(h, h, false, vb.pp("q_proj"))?; + let out_proj = linear(h, h, false, vb.pp("out_proj"))?; Ok(Self { scaling: 1. / (head_dim as f64).sqrt(), is_decoder: true, @@ -212,16 +212,15 @@ struct MusicgenDecoderLayer { } impl MusicgenDecoderLayer { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { let h = cfg.hidden_size; - let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?; - let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?; - let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?; - let encoder_attn_layer_norm = - layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?; - let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?; - let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?; - let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?; + let self_attn = MusicgenAttention::load(vb.pp("self_attn"), cfg)?; + let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?; + let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?; + let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?; + let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?; + let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?; + let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?; Ok(Self { self_attn, self_attn_layer_norm, @@ -276,7 +275,7 @@ struct MusicgenDecoder { } impl MusicgenDecoder { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { let h = cfg.hidden_size; let embed_scale = if cfg.scale_embedding { (h as f64).sqrt() @@ -285,13 +284,13 @@ impl MusicgenDecoder { }; let embed_dim = cfg.vocab_size + 1; let embed_tokens = (0..cfg.num_codebooks) - .map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb)) + .map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}")))) .collect::>>()?; - let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?; + let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?; let layers = (0..cfg.num_hidden_layers) - .map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg)) + .map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg)) .collect::>>()?; - let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?; + let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?; Ok(Self { embed_tokens, embed_positions, @@ -338,11 +337,11 @@ pub struct MusicgenForCausalLM { } impl MusicgenForCausalLM { - pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let h = cfg.hidden_size; - let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?; + let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?; let lm_heads = (0..cfg.num_codebooks) - .map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb)) + .map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}")))) .collect::>>()?; Ok(Self { decoder, @@ -399,10 +398,11 @@ impl MusicgenForConditionalGeneration { &self.cfg } - pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result { - let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?; - let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?; - let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?; + pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result { + let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; + let audio_encoder = + encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?; + let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; Ok(Self { text_encoder, audio_encoder, diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs index 235de1ec..81643466 100644 --- a/candle-examples/examples/musicgen/nn.rs +++ b/candle-examples/examples/musicgen/nn.rs @@ -8,10 +8,10 @@ const MAX_SEQ_LEN: usize = 5000; pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>; pub type Linear = candle_nn::Linear; -pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result { - let weight = vb.get((size2, size1), &format!("{p}.weight"))?; +pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; let bias = if bias { - Some(vb.get(size2, &format!("{p}.bias"))?) + Some(vb.get(size2, "bias")?) } else { None }; @@ -20,17 +20,11 @@ pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) pub type LayerNorm = candle_nn::LayerNorm; -pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result { - let (weight, bias) = match ( - vb.get(size, &format!("{p}.weight")), - vb.get(size, &format!("{p}.bias")), - ) { +pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { (Ok(weight), Ok(bias)) => (weight, bias), (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = ( - vb.get(size, &format!("{p}.gamma")), - vb.get(size, &format!("{p}.beta")), - ) { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { (weight, bias) } else { return Err(err.into()); @@ -58,13 +52,8 @@ impl Dropout { pub type Embedding = candle_nn::Embedding; -pub fn embedding( - vocab_size: usize, - hidden_size: usize, - p: &str, - vb: &VarBuilder, -) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?; +pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } @@ -79,14 +68,13 @@ pub fn conv1d_weight_norm( out_c: usize, kernel_size: usize, config: Conv1dConfig, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, ) -> Result { - let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?; - let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?; + let weight_g = vb.get((out_c, 1, 1), "weight_g")?; + let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?; let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; - let bias = vb.get(out_c, &format!("{p}.bias"))?; + let bias = vb.get(out_c, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } @@ -95,11 +83,10 @@ pub fn conv1d( out_c: usize, kernel_size: usize, config: Conv1dConfig, - p: &str, - vb: &VarBuilder, + vb: VarBuilder, ) -> Result { - let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?; - let bias = vb.get(out_c, &format!("{p}.bias"))?; + let weight = vb.get((out_c, in_c, kernel_size), "weight")?; + let bias = vb.get(out_c, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs index c904d67c..ad784224 100644 --- a/candle-examples/examples/musicgen/t5_model.rs +++ b/candle-examples/examples/musicgen/t5_model.rs @@ -85,8 +85,8 @@ struct T5LayerNorm { } impl T5LayerNorm { - fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result { - let weight = vb.get(h, &format!("{p}.weight"))?; + fn load(h: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(h, "weight")?; Ok(Self { weight, variance_epsilon: eps, @@ -103,9 +103,9 @@ struct T5DenseActDense { } impl T5DenseActDense { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?; - let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?; + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?; + let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?; let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { wi, @@ -124,15 +124,11 @@ struct T5LayerFF { } impl T5LayerFF { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { // is_gated_act is not supported. - let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?; - let layer_norm = T5LayerNorm::load( - cfg.d_model, - cfg.layer_norm_epsilon, - &format!("{p}.layer_norm"), - vb, - )?; + let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { dense_relu_dense, @@ -152,18 +148,17 @@ struct T5Attention { } impl T5Attention { - fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result { let inner_dim = cfg.num_heads * cfg.d_kv; - let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?; - let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?; - let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?; - let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?; + let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?; + let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?; + let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?; + let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?; let relative_attention_bias = if h { let emb = embedding( cfg.relative_attention_num_buckets, cfg.num_heads, - &format!("{p}.relative_attention_bias"), - vb, + vb.pp("relative_attention_bias"), )?; Some(emb) } else { @@ -187,14 +182,10 @@ struct T5LayerSelfAttention { } impl T5LayerSelfAttention { - fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?; - let layer_norm = T5LayerNorm::load( - cfg.d_model, - cfg.layer_norm_epsilon, - &format!("{p}.layer_norm"), - vb, - )?; + fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result { + let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { self_attention, @@ -208,7 +199,7 @@ impl T5LayerSelfAttention { struct T5LayerCrossAttention {} impl T5LayerCrossAttention { - fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result { + fn load(_vb: VarBuilder, _cfg: &Config) -> Result { todo!() } } @@ -221,22 +212,16 @@ struct T5Block { } impl T5Block { - fn load( - has_relative_attention_bias: bool, - p: &str, - vb: &VarBuilder, - cfg: &Config, - ) -> Result { - let p = &format!("{p}.layer"); - let self_attn = - T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?; + fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result { + let vb = vb.pp("layer"); + let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?; let cross_attn = if cfg.is_decoder { - Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?) + Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?) } else { None }; let ff_i = if cross_attn.is_some() { 2 } else { 1 }; - let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?; + let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?; Ok(Self { self_attn, cross_attn, @@ -254,15 +239,14 @@ struct T5Stack { } impl T5Stack { - fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + fn load(vb: VarBuilder, cfg: &Config) -> Result { let block = (0..cfg.num_layers) - .map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg)) + .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg)) .collect::>>()?; let final_layer_norm = T5LayerNorm::load( cfg.d_model, cfg.layer_norm_epsilon, - &format!("{p}.final_layer_norm"), - vb, + vb.pp("final_layer_norm"), )?; let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { @@ -280,9 +264,9 @@ pub struct T5EncoderModel { } impl T5EncoderModel { - pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { - let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?; - let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?; + pub fn load(vb: VarBuilder, cfg: &Config) -> Result { + let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let encoder = T5Stack::load(vb.pp("encoder"), cfg)?; Ok(Self { shared, encoder }) } }