Simplify the var-builder layer setup. (#133)

This commit is contained in:
Laurent Mazare
2023-07-10 23:22:58 +01:00
committed by GitHub
parent 6fc1ab4f0d
commit 0e9d3afd77

View File

@ -262,22 +262,16 @@ struct EncodecResnetBlock {
impl EncodecResnetBlock {
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new("block");
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
anyhow::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
let block_conv1 = EncodecConv1d::load(
dim,
h,
cfg.residual_kernel_size,
1,
vb.pp(&layer.next_name()),
cfg,
)?;
let block_conv1 =
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
layer.inc();
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?;
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
Some(conv)
@ -292,28 +286,24 @@ impl EncodecResnetBlock {
}
}
// NOTE(review): this region is a rendered diff hunk whose +/- markers were stripped, so two
// versions of `Layer` are interleaved below. Presumably the string-prefix lines are the
// removed ones and the `VarBuilder` lines the added ones (unified-diff order) — confirm
// against the commit before relying on that.
//
// `Layer` is a small helper that numbers consecutive sub-layers with a running counter.
#[derive(Debug)]
// String-prefix variant: keeps the textual prefix.
struct Layer {
prefix: String,
// VarBuilder variant: keeps the `VarBuilder` handed to `new`.
struct Layer<'a> {
vb: VarBuilder<'a>,
// Index that will be used for the next sub-layer (shared by both variants).
cnt: usize,
}
// String-prefix variant of the impl block.
impl Layer {
// Copies `prefix` into an owned `String` and starts the counter at 0.
fn new(prefix: &str) -> Self {
Self {
prefix: prefix.to_string(),
cnt: 0,
}
// VarBuilder variant of the impl block.
impl<'a> Layer<'a> {
// Takes ownership of `vb` and starts the counter at 0.
fn new(vb: VarBuilder<'a>) -> Self {
Self { vb, cnt: 0 }
}
// Advances the counter without producing anything, i.e. skips one index.
fn inc(&mut self) {
self.cnt += 1;
}
// String-prefix variant: builds "<prefix>.<cnt>", then advances the counter.
fn next_name(&mut self) -> String {
let name = format!("{}.{}", self.prefix, self.cnt);
// VarBuilder variant: calls `pp` on the stored builder with the current counter rendered
// as a decimal string, then advances the counter.
fn next(&mut self) -> VarBuilder<'a> {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
// Return value of the string-prefix variant.
name
// Return value of the VarBuilder variant.
vb
}
}
@ -327,13 +317,13 @@ struct EncodecEncoder {
impl EncodecEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new("layers");
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
let mut sampling_layers = vec![];
@ -345,7 +335,7 @@ impl EncodecEncoder {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
resnets.push(resnet)
@ -356,21 +346,20 @@ impl EncodecEncoder {
current_scale * 2,
ratio * 2,
ratio,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm =
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
Ok(Self {
@ -392,18 +381,17 @@ struct EncodecDecoder {
impl EncodecDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new("layers");
let mut layer = Layer::new(vb.pp("layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
let init_lstm =
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
@ -413,7 +401,7 @@ impl EncodecDecoder {
current_scale / 2,
ratio * 2,
ratio,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
let mut resnets = vec![];
@ -421,7 +409,7 @@ impl EncodecDecoder {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
resnets.push(resnet)
@ -435,7 +423,7 @@ impl EncodecDecoder {
cfg.audio_channels,
cfg.last_kernel_size,
1,
vb.pp(&layer.next_name()),
layer.next(),
cfg,
)?;
Ok(Self {