From 0e9d3afd775c4ca466444e26908fd3f9291978b7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 10 Jul 2023 23:22:58 +0100 Subject: [PATCH] Simplify the var-builder layer setup. (#133) --- .../examples/musicgen/encodec_model.rs | 60 ++++++++----------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 242c349c..ed8a66b7 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -262,22 +262,16 @@ struct EncodecResnetBlock { impl EncodecResnetBlock { fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result { let h = dim / cfg.compress; - let mut layer = Layer::new("block"); + let mut layer = Layer::new(vb.pp("block")); if dilations.len() != 2 { anyhow::bail!("expected dilations of size 2") } // TODO: Apply dilations! layer.inc(); - let block_conv1 = EncodecConv1d::load( - dim, - h, - cfg.residual_kernel_size, - 1, - vb.pp(&layer.next_name()), - cfg, - )?; + let block_conv1 = + EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?; layer.inc(); - let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?; + let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?; let shortcut = if cfg.use_conv_shortcut { let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?; Some(conv) @@ -292,28 +286,24 @@ impl EncodecResnetBlock { } } -#[derive(Debug)] -struct Layer { - prefix: String, +struct Layer<'a> { + vb: VarBuilder<'a>, cnt: usize, } -impl Layer { - fn new(prefix: &str) -> Self { - Self { - prefix: prefix.to_string(), - cnt: 0, - } +impl<'a> Layer<'a> { + fn new(vb: VarBuilder<'a>) -> Self { + Self { vb, cnt: 0 } } fn inc(&mut self) { self.cnt += 1; } - fn next_name(&mut self) -> String { - let name = format!("{}.{}", self.prefix, self.cnt); + fn next(&mut self) -> VarBuilder<'a> { + let vb = self.vb.pp(&self.cnt.to_string()); self.cnt += 1; - name + vb } } @@ -327,13 +317,13 @@ struct EncodecEncoder { impl EncodecEncoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { - let mut layer = Layer::new("layers"); + let mut layer = Layer::new(vb.pp("layers")); let init_conv = EncodecConv1d::load( cfg.audio_channels, cfg.num_filters, cfg.kernel_size, 1, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; let mut sampling_layers = vec![]; @@ -345,7 +335,7 @@ impl EncodecEncoder { let resnet = EncodecResnetBlock::load( current_scale, &[cfg.dilation_growth_rate.pow(j), 1], - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; resnets.push(resnet) @@ -356,21 +346,20 @@ impl EncodecEncoder { current_scale * 2, ratio * 2, ratio, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; sampling_layers.push((resnets, conv1d)); scaling *= 2; } - let final_lstm = - EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?; + let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?; layer.inc(); // ELU let final_conv = EncodecConv1d::load( cfg.num_filters * scaling, cfg.hidden_size, cfg.last_kernel_size, 1, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; Ok(Self { @@ -392,18 +381,17 @@ struct EncodecDecoder { impl EncodecDecoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { - let mut layer = Layer::new("layers"); + let mut layer = Layer::new(vb.pp("layers")); let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32); let init_conv = EncodecConv1d::load( cfg.hidden_size, cfg.num_filters * scaling, cfg.last_kernel_size, 1, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; - let init_lstm = - EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?; + let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?; let mut sampling_layers = vec![]; for &ratio in cfg.upsampling_ratios.iter() { let current_scale = scaling * cfg.num_filters; @@ -413,7 +401,7 @@ impl EncodecDecoder { current_scale / 2, ratio * 2, ratio, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; let mut resnets = vec![]; @@ -421,7 +409,7 @@ impl EncodecDecoder { let resnet = EncodecResnetBlock::load( current_scale / 2, &[cfg.dilation_growth_rate.pow(j), 1], - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; resnets.push(resnet) @@ -435,7 +423,7 @@ impl EncodecDecoder { cfg.audio_channels, cfg.last_kernel_size, 1, - vb.pp(&layer.next_name()), + layer.next(), cfg, )?; Ok(Self {